Commit 1d085d98 authored by song's avatar song
Browse files

feat: 完善 Antigravity 多平台网关支持,修复 Gemini handler 分流逻辑

parent 6648e650
...@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
...@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
......
...@@ -21,27 +21,30 @@ import ( ...@@ -21,27 +21,30 @@ import (
// GatewayHandler handles API gateway requests // GatewayHandler handles API gateway requests
type GatewayHandler struct { type GatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService geminiCompatService *service.GeminiMessagesCompatService
userService *service.UserService antigravityGatewayService *service.AntigravityGatewayService
billingCacheService *service.BillingCacheService userService *service.UserService
concurrencyHelper *ConcurrencyHelper billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
} }
// NewGatewayHandler creates a new GatewayHandler // NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler( func NewGatewayHandler(
gatewayService *service.GatewayService, gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService, geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService, userService *service.UserService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
) *GatewayHandler { ) *GatewayHandler {
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
geminiCompatService: geminiCompatService, geminiCompatService: geminiCompatService,
userService: userService, antigravityGatewayService: antigravityGatewayService,
billingCacheService: billingCacheService, userService: userService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
} }
} }
...@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 转发请求 // 转发请求 - 根据账号平台分流
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
...@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 转发请求 // 转发请求 - 根据账号平台分流
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
......
...@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { ...@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil { if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型列表
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
...@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { ...@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil { if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型信息
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
...@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return return
} }
// 5) forward (writes response to client) // 5) forward (根据平台分流)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
} else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
......
//go:build unit
package handler
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string
description string
}{
{
name: "Gemini平台使用ForwardNative",
platform: service.PlatformGemini,
expectedService: "GeminiMessagesCompatService.ForwardNative",
description: "Gemini OAuth 账户直接调用 Google API",
},
{
name: "Antigravity平台使用ForwardGemini",
platform: service.PlatformAntigravity,
expectedService: "AntigravityGatewayService.ForwardGemini",
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
var routedService string
if tt.platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
require.Equal(t, tt.expectedService, routedService,
"平台 %s 应该路由到 %s: %s",
tt.platform, tt.expectedService, tt.description)
})
}
}
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态列表",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_fallback",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_fallback"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态模型信息",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_model_info",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_model_info"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}
...@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont ...@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
return outAccounts, nil return outAccounts, nil
} }
func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform IN ?", platforms).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform IN ?", platforms).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
......
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type GatewayRoutingSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
accountRepo *accountRepository
}
func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.accountRepo = NewAccountRepository(s.db).(*accountRepository)
}
func TestGatewayRoutingSuite(t *testing.T) {
suite.Run(t, new(GatewayRoutingSuite))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
// 创建各平台账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-oauth",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 1,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-oauth",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 2,
Credentials: datatypes.JSONMap{
"access_token": "test-token",
"refresh_token": "test-refresh",
"project_id": "test-project",
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "anthropic-oauth",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 0,
})
// 查询 gemini + antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
// 验证返回的账户平台
platforms := make(map[string]bool)
for _, acc := range accounts {
platforms[acc.Platform] = true
}
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
// 验证账户 ID 匹配
ids := make(map[int64]bool)
for _, acc := range accounts {
ids[acc.ID] = true
}
s.Require().True(ids[geminiAcc.ID])
s.Require().True(ids[antigravityAcc.ID])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
// 创建 gemini 分组
group := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "gemini-group",
Platform: service.PlatformGemini,
Status: service.StatusActive,
})
// 创建账户
boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "unbound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只绑定一个账户到分组
mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1)
// 查询分组内的账户
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
s.Require().Equal(boundAcc.ID, accounts[0].ID)
// 确认未绑定的账户不在结果中
for _, acc := range accounts {
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// 创建多种平台账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-1",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravity := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-1",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只查询 antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(antigravity.ID, accounts[0].ID)
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 创建可调度账户
activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "active-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "inactive-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
})
s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error)
// 创建错误状态账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "error-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusError,
Schedulable: true,
})
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
s.Require().Equal(activeAcc.ID, accounts[0].ID)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
// 创建两种平台的账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-route-test",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-route-test",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
tests := []struct {
name string
accountID int64
expectedService string
}{
{
name: "Gemini账户路由到ForwardNative",
accountID: geminiAcc.ID,
expectedService: "GeminiMessagesCompatService.ForwardNative",
},
{
name: "Antigravity账户路由到ForwardGemini",
accountID: antigravityAcc.ID,
expectedService: "AntigravityGatewayService.ForwardGemini",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 从数据库获取账户
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
s.Require().NoError(err)
// 模拟 Handler 层的路由决策
var routedService string
if account.Platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
s.Require().Equal(tt.expectedService, routedService)
})
}
}
...@@ -38,6 +38,8 @@ type AccountRepository interface { ...@@ -38,6 +38,8 @@ type AccountRepository interface {
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
......
This diff is collapsed.
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAntigravityModelSupported(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// 直接支持的模型
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
// 可映射的模型
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// Claude 前缀兜底
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
{"Claude前缀 - claude-future-version", "claude-future-version", true},
// 不支持的模型
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false},
{"不支持 - mistral-7b", "mistral-7b", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsAntigravityModelSupported(tt.model)
require.Equal(t, tt.expected, got, "model: %s", tt.model)
})
}
}
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-haiku-4",
requestedModel: "claude-haiku-4",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-haiku-20240307",
requestedModel: "claude-3-haiku-20240307",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 透传
{
name: "Gemini透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
},
{
name: "Gemini透传 - gemini-1.5-pro",
requestedModel: "gemini-1.5-pro",
accountMapping: nil,
expected: "gemini-1.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "gemini-future-model",
},
// 4. 直接支持的模型
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 5. 默认值 fallback(未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
}
if tt.accountMapping != nil {
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny := make(map[string]any)
for k, v := range tt.accountMapping {
mappingAny[k] = v
}
account.Credentials = map[string]any{
"model_mapping": mappingAny,
}
}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
})
}
}
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: PlatformAntigravity}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got)
})
}
}
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
model string
expected bool
}{
// 直接支持
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 前缀透传
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
// 不支持
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.IsModelSupported(tt.model)
require.Equal(t, tt.expected, got)
})
}
}
package service
import (
"context"
"errors"
"log"
"strconv"
"strings"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenProvider(
accountRepo AccountRepository,
tokenCache AntigravityTokenCache,
antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
return &AntigravityTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account)
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseAntigravityExpiresAt(account)
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = parseAntigravityExpiresAt(account)
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}
package service
import (
"context"
"strconv"
"time"
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type AntigravityTokenRefresher struct {
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
return &AntigravityTokenRefresher{
antigravityOAuthService: antigravityOAuthService,
}
}
// CanRefresh 检查是否可以刷新此账户
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
}
// Refresh 执行 token 刷新
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
return newCredentials, nil
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForMultiplatform 多平台测试用的 mock
type mockAccountRepoForMultiplatform struct {
accounts []Account
accountsByID map[int64]*Account
listPlatformsFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForMultiplatform) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listPlatformsFunc != nil {
return m.listPlatformsFunc(ctx, platforms)
}
// 过滤符合平台的账户
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForMultiplatform) Create(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) Update(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForMultiplatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListActive(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) UpdateLastUsed(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForMultiplatform)(nil)
// mockGatewayCacheForMultiplatform 多平台测试用的 cache mock
type mockGatewayCacheForMultiplatform struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForMultiplatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForMultiplatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForMultiplatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
func ptr[T any](v T) *T {
return &v
}
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAnthropic(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
}
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
}
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
ctx := context.Background()
now := time.Now()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)")
}
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_DiffPriority(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级更高的账户(Antigravity, priority=1)")
}
func TestGatewayService_SelectAccountForModelWithExclusions_ModelNotSupported(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
// Anthropic 账户配置了模型映射,只支持 other-model
// 注意:model_mapping 需要是 map[string]any 格式
{
ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"other-model": "x"}},
},
// Antigravity 账户支持所有 claude 模型
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "Anthropic 不支持该模型,应选择 Antigravity")
}
func TestGatewayService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
}
func TestGatewayService_SelectAccountForModelWithExclusions_AllExcluded(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
require.Error(t, err)
require.Nil(t, acc)
}
func TestGatewayService_SelectAccountForModelWithExclusions_Schedulability(t *testing.T) {
ctx := context.Background()
now := time.Now()
tests := []struct {
name string
accounts []Account
expectedID int64
}{
{
name: "过载账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "限流账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "非active账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "schedulable=false被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "过期的过载账户可调度",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: tt.accounts,
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, tt.expectedID, acc.ID)
})
}
}
func TestGatewayService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
excludedIDs := map[int64]struct{}{1: {}}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户")
})
t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户")
})
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Anthropic平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformAnthropic},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Anthropic平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Anthropic平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}
...@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 使用多平台账户选择,包含 anthropic 和 antigravity 平台
platforms := []string{PlatformAnthropic, PlatformAntigravity}
return s.selectAccountForModelWithPlatforms(ctx, groupID, sessionHash, requestedModel, excludedIDs, platforms)
}
// selectAccountForModelWithPlatforms 选择多平台账户
func (s *GatewayService) selectAccountForModelWithPlatforms(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platforms []string) (*Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
...@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 同时检查模型支持 // 同时检查模型支持(根据平台类型分别处理)
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
// 续期粘性会话 // 续期粘性会话
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
...@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
} }
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) // 2. 获取可调度账号列表(排除限流和过载的账号,支持多平台)
var accounts []Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
...@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// 检查模型支持 // 检查模型支持(根据平台类型分别处理)
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return selected, nil return selected, nil
} }
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
}
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
......
...@@ -33,11 +33,12 @@ const ( ...@@ -33,11 +33,12 @@ const (
) )
type GeminiMessagesCompatService struct { type GeminiMessagesCompatService struct {
accountRepo AccountRepository accountRepo AccountRepository
cache GatewayCache cache GatewayCache
tokenProvider *GeminiTokenProvider tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
} }
func NewGeminiMessagesCompatService( func NewGeminiMessagesCompatService(
...@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService( ...@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService(
tokenProvider *GeminiTokenProvider, tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
) *GeminiMessagesCompatService { ) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{ return &GeminiMessagesCompatService{
accountRepo: accountRepo, accountRepo: accountRepo,
cache: cache, cache: cache,
tokenProvider: tokenProvider, tokenProvider: tokenProvider,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
} }
} }
...@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, ...@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
cacheKey := "gemini:" + sessionHash cacheKey := "gemini:" + sessionHash
platforms := []string{PlatformGemini, PlatformAntigravity}
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // 支持 gemini 和 antigravity 平台的粘性会话
if err == nil && account.IsSchedulable() && (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil return account, nil
} }
...@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
} }
} }
// 同时查询 gemini 和 antigravity 平台的可调度账户
var accounts []Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
...@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { // 根据平台类型分别检查模型支持
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if selected == nil { if selected == nil {
if requestedModel != "" { if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel) return nil, fmt.Errorf("no available Gemini/Antigravity accounts supporting model: %s", requestedModel)
} }
return nil, errors.New("no available Gemini accounts") return nil, errors.New("no available Gemini/Antigravity accounts")
} }
if sessionHash != "" { if sessionHash != "" {
...@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil return selected, nil
} }
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
}
return account.IsModelSupported(requestedModel)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
}
if err != nil {
return false, err
}
return len(accounts) > 0, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against // SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models). // generativelanguage.googleapis.com (e.g. GET /v1beta/models).
// //
......
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) { return nil, nil }
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
// 测试时不区分 groupID,直接按 platform 过滤
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyGemini(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludesAnthropic(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 3, Platform: PlatformAntigravity, Priority: 3, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// Anthropic 不在 [gemini, antigravity] 平台列表中,应被过滤
require.Equal(t, int64(2), acc.ID, "Anthropic 平台应被排除,选择 Gemini")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
ctx := context.Background()
now := time.Now()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
require.Equal(t, AccountTypeOAuth, acc.Type)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred_MixedPlatforms(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "跨平台时,同样优先选择 OAuth 账户")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available Gemini/Antigravity accounts")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中-使用gemini前缀缓存键", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 注意:缓存键使用 "gemini:" 前缀
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 缓存键没有 "gemini:" 前缀,不应命中
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择 Antigravity")
})
t.Run("粘性会话Anthropic账户-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// 粘性会话绑定的是 Anthropic 账户,不在 Gemini 平台列表中,应降级选择
require.Equal(t, int64(2), acc.ID, "粘性会话账户是 Anthropic,应降级选择 Gemini 平台账户")
})
}
func TestGeminiMessagesCompatService_HasAntigravityAccounts(t *testing.T) {
ctx := context.Background()
t.Run("有antigravity账户时返回true", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.True(t, has)
})
t.Run("无antigravity账户时返回false", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.False(t, has)
})
t.Run("antigravity账户不可调度时返回false", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: false},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.False(t, has)
})
t.Run("带groupID查询", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
groupID := int64(1)
has, err := svc.HasAntigravityAccounts(ctx, &groupID)
require.NoError(t, err)
require.True(t, has)
})
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
// 该测试文档化了 Handler 层应该如何根据 account.Platform 进行分流
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
}{
{
name: "Gemini平台走ForwardNative",
platform: PlatformGemini,
expectedService: "gemini",
},
{
name: "Antigravity平台走ForwardGemini",
platform: PlatformAntigravity,
expectedService: "antigravity",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: tt.platform}
// 模拟 Handler 层的路由逻辑
var serviceName string
if account.Platform == PlatformAntigravity {
serviceName = "antigravity"
} else {
serviceName = "gemini"
}
require.Equal(t, tt.expectedService, serviceName,
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
})
}
}
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}
...@@ -27,6 +27,7 @@ func NewTokenRefreshService( ...@@ -27,6 +27,7 @@ func NewTokenRefreshService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
s := &TokenRefreshService{ s := &TokenRefreshService{
...@@ -40,6 +41,7 @@ func NewTokenRefreshService( ...@@ -40,6 +41,7 @@ func NewTokenRefreshService(
NewClaudeTokenRefresher(oauthService), NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService), NewOpenAITokenRefresher(openaiOAuthService),
NewGeminiTokenRefresher(geminiOAuthService), NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
} }
return s return s
......
...@@ -39,9 +39,10 @@ func ProvideTokenRefreshService( ...@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
svc.Start() svc.Start()
return svc return svc
} }
...@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet( ...@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet(
NewAntigravityOAuthService, NewAntigravityOAuthService,
NewGeminiTokenProvider, NewGeminiTokenProvider,
NewGeminiMessagesCompatService, NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
NewAntigravityGatewayService,
NewRateLimitService, NewRateLimitService,
NewAccountUsageService, NewAccountUsageService,
NewAccountTestService, NewAccountTestService,
......
...@@ -62,6 +62,10 @@ const filteredGroups = computed(() => { ...@@ -62,6 +62,10 @@ const filteredGroups = computed(() => {
if (!props.platform) { if (!props.platform) {
return props.groups return props.groups
} }
// antigravity 账户可选择 anthropic 和 gemini 平台的分组
if (props.platform === 'antigravity') {
return props.groups.filter((g) => g.platform === 'anthropic' || g.platform === 'gemini')
}
return props.groups.filter((g) => g.platform === props.platform) return props.groups.filter((g) => g.platform === props.platform)
}) })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment