Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
1d085d98
Commit
1d085d98
authored
Dec 28, 2025
by
song
Browse files
feat: 完善 Antigravity 多平台网关支持,修复 Gemini handler 分流逻辑
parent
6648e650
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
1d085d98
...
@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService
:=
service
.
ProvideTimingWheelService
()
timingWheelService
:=
service
.
ProvideTimingWheelService
()
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
)
antigravityTokenProvider
:=
service
.
NewAntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
userService
,
concurrencyService
,
billingCacheService
)
antigravityGatewayService
:=
service
.
NewAntigravityGatewayService
(
accountRepository
,
gatewayCache
,
antigravityTokenProvider
,
rateLimitService
,
httpUpstream
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
userService
,
concurrencyService
,
billingCacheService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
)
handlerSettingHandler
:=
handler
.
ProvideSettingHandler
(
settingService
,
buildInfo
)
handlerSettingHandler
:=
handler
.
ProvideSettingHandler
(
settingService
,
buildInfo
)
...
@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthMiddleware
:=
middleware
.
NewApiKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
)
apiKeyAuthMiddleware
:=
middleware
.
NewApiKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
)
engine
:=
server
.
ProvideRouter
(
configConfig
,
handlers
,
jwtAuthMiddleware
,
adminAuthMiddleware
,
apiKeyAuthMiddleware
,
apiKeyService
,
subscriptionService
)
engine
:=
server
.
ProvideRouter
(
configConfig
,
handlers
,
jwtAuthMiddleware
,
adminAuthMiddleware
,
apiKeyAuthMiddleware
,
apiKeyService
,
subscriptionService
)
httpServer
:=
server
.
ProvideHTTPServer
(
configConfig
,
engine
)
httpServer
:=
server
.
ProvideHTTPServer
(
configConfig
,
engine
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
configConfig
)
v
:=
provideCleanup
(
db
,
client
,
tokenRefreshService
,
pricingService
,
emailQueueService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
v
:=
provideCleanup
(
db
,
client
,
tokenRefreshService
,
pricingService
,
emailQueueService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
application
:=
&
Application
{
application
:=
&
Application
{
Server
:
httpServer
,
Server
:
httpServer
,
...
...
backend/internal/handler/gateway_handler.go
View file @
1d085d98
...
@@ -21,27 +21,30 @@ import (
...
@@ -21,27 +21,30 @@ import (
// GatewayHandler handles API gateway requests
// GatewayHandler handles API gateway requests
type
GatewayHandler
struct
{
type
GatewayHandler
struct
{
gatewayService
*
service
.
GatewayService
gatewayService
*
service
.
GatewayService
geminiCompatService
*
service
.
GeminiMessagesCompatService
geminiCompatService
*
service
.
GeminiMessagesCompatService
userService
*
service
.
UserService
antigravityGatewayService
*
service
.
AntigravityGatewayService
billingCacheService
*
service
.
BillingCacheService
userService
*
service
.
UserService
concurrencyHelper
*
ConcurrencyHelper
billingCacheService
*
service
.
BillingCacheService
concurrencyHelper
*
ConcurrencyHelper
}
}
// NewGatewayHandler creates a new GatewayHandler
// NewGatewayHandler creates a new GatewayHandler
func
NewGatewayHandler
(
func
NewGatewayHandler
(
gatewayService
*
service
.
GatewayService
,
gatewayService
*
service
.
GatewayService
,
geminiCompatService
*
service
.
GeminiMessagesCompatService
,
geminiCompatService
*
service
.
GeminiMessagesCompatService
,
antigravityGatewayService
*
service
.
AntigravityGatewayService
,
userService
*
service
.
UserService
,
userService
*
service
.
UserService
,
concurrencyService
*
service
.
ConcurrencyService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
billingCacheService
*
service
.
BillingCacheService
,
)
*
GatewayHandler
{
)
*
GatewayHandler
{
return
&
GatewayHandler
{
return
&
GatewayHandler
{
gatewayService
:
gatewayService
,
gatewayService
:
gatewayService
,
geminiCompatService
:
geminiCompatService
,
geminiCompatService
:
geminiCompatService
,
userService
:
userService
,
antigravityGatewayService
:
antigravityGatewayService
,
billingCacheService
:
billingCacheService
,
userService
:
userService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
),
billingCacheService
:
billingCacheService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
),
}
}
}
}
...
@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
return
}
}
// 转发请求
// 转发请求 - 根据账号平台分流
result
,
err
:=
h
.
geminiCompatService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
var
result
*
service
.
ForwardResult
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
Request
.
Context
(),
c
,
account
,
req
.
Model
,
"generateContent"
,
req
.
Stream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
}
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
...
@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
return
}
}
// 转发请求
// 转发请求 - 根据账号平台分流
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
var
result
*
service
.
ForwardResult
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
}
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
1d085d98
...
@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
...
@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
account
,
err
:=
h
.
geminiCompatService
.
SelectAccountForAIStudioEndpoints
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
account
,
err
:=
h
.
geminiCompatService
.
SelectAccountForAIStudioEndpoints
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
if
err
!=
nil
{
if
err
!=
nil
{
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity
,
_
:=
h
.
geminiCompatService
.
HasAntigravityAccounts
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
if
hasAntigravity
{
// antigravity 账户使用静态模型列表
c
.
JSON
(
http
.
StatusOK
,
gemini
.
FallbackModelsList
())
return
}
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
return
return
}
}
...
@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
...
@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
account
,
err
:=
h
.
geminiCompatService
.
SelectAccountForAIStudioEndpoints
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
account
,
err
:=
h
.
geminiCompatService
.
SelectAccountForAIStudioEndpoints
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
if
err
!=
nil
{
if
err
!=
nil
{
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity
,
_
:=
h
.
geminiCompatService
.
HasAntigravityAccounts
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
if
hasAntigravity
{
// antigravity 账户使用静态模型信息
c
.
JSON
(
http
.
StatusOK
,
gemini
.
FallbackModel
(
modelName
))
return
}
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
return
return
}
}
...
@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
return
}
}
// 5) forward (writes response to client)
// 5) forward (根据平台分流)
result
,
err
:=
h
.
geminiCompatService
.
ForwardNative
(
c
.
Request
.
Context
(),
c
,
account
,
modelName
,
action
,
stream
,
body
)
var
result
*
service
.
ForwardResult
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
Request
.
Context
(),
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
c
.
Request
.
Context
(),
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
...
...
backend/internal/handler/gemini_v1beta_handler_test.go
0 → 100644
View file @
1d085d98
//go:build unit
package
handler
import
(
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
func
TestGeminiV1BetaHandler_PlatformRoutingInvariant
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
platform
string
expectedService
string
description
string
}{
{
name
:
"Gemini平台使用ForwardNative"
,
platform
:
service
.
PlatformGemini
,
expectedService
:
"GeminiMessagesCompatService.ForwardNative"
,
description
:
"Gemini OAuth 账户直接调用 Google API"
,
},
{
name
:
"Antigravity平台使用ForwardGemini"
,
platform
:
service
.
PlatformAntigravity
,
expectedService
:
"AntigravityGatewayService.ForwardGemini"
,
description
:
"Antigravity 账户通过 CRS 中转,支持 Gemini 协议"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
var
routedService
string
if
tt
.
platform
==
service
.
PlatformAntigravity
{
routedService
=
"AntigravityGatewayService.ForwardGemini"
}
else
{
routedService
=
"GeminiMessagesCompatService.ForwardNative"
}
require
.
Equal
(
t
,
tt
.
expectedService
,
routedService
,
"平台 %s 应该路由到 %s: %s"
,
tt
.
platform
,
tt
.
expectedService
,
tt
.
description
)
})
}
}
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
func
TestGeminiV1BetaHandler_ListModelsAntigravityFallback
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
hasGeminiAccount
bool
hasAntigravity
bool
expectedBehavior
string
}{
{
name
:
"有Gemini账户-调用ForwardAIStudioGET"
,
hasGeminiAccount
:
true
,
hasAntigravity
:
false
,
expectedBehavior
:
"forward_to_upstream"
,
},
{
name
:
"无Gemini有Antigravity-返回静态列表"
,
hasGeminiAccount
:
false
,
hasAntigravity
:
true
,
expectedBehavior
:
"static_fallback"
,
},
{
name
:
"无任何账户-返回503"
,
hasGeminiAccount
:
false
,
hasAntigravity
:
false
,
expectedBehavior
:
"service_unavailable"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
var
behavior
string
if
tt
.
hasGeminiAccount
{
behavior
=
"forward_to_upstream"
}
else
if
tt
.
hasAntigravity
{
behavior
=
"static_fallback"
}
else
{
behavior
=
"service_unavailable"
}
require
.
Equal
(
t
,
tt
.
expectedBehavior
,
behavior
)
})
}
}
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
func
TestGeminiV1BetaHandler_GetModelAntigravityFallback
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
hasGeminiAccount
bool
hasAntigravity
bool
expectedBehavior
string
}{
{
name
:
"有Gemini账户-调用ForwardAIStudioGET"
,
hasGeminiAccount
:
true
,
hasAntigravity
:
false
,
expectedBehavior
:
"forward_to_upstream"
,
},
{
name
:
"无Gemini有Antigravity-返回静态模型信息"
,
hasGeminiAccount
:
false
,
hasAntigravity
:
true
,
expectedBehavior
:
"static_model_info"
,
},
{
name
:
"无任何账户-返回503"
,
hasGeminiAccount
:
false
,
hasAntigravity
:
false
,
expectedBehavior
:
"service_unavailable"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
var
behavior
string
if
tt
.
hasGeminiAccount
{
behavior
=
"forward_to_upstream"
}
else
if
tt
.
hasAntigravity
{
behavior
=
"static_model_info"
}
else
{
behavior
=
"service_unavailable"
}
require
.
Equal
(
t
,
tt
.
expectedBehavior
,
behavior
)
})
}
}
backend/internal/repository/account_repo.go
View file @
1d085d98
...
@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
...
@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
return
outAccounts
,
nil
return
outAccounts
,
nil
}
}
func
(
r
*
accountRepository
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
}
var
accounts
[]
accountModel
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform IN ?"
,
platforms
)
.
Where
(
"status = ? AND schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
func
(
r
*
accountRepository
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
}
var
accounts
[]
accountModel
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"accounts.platform IN ?"
,
platforms
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
func
(
r
*
accountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
func
(
r
*
accountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
accountModel
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
accountModel
{})
.
Where
(
"id = ?"
,
id
)
.
...
...
backend/internal/repository/gateway_routing_integration_test.go
0 → 100644
View file @
1d085d98
//go:build integration
package
repository
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type
GatewayRoutingSuite
struct
{
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
accountRepo
*
accountRepository
}
func
(
s
*
GatewayRoutingSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
accountRepo
=
NewAccountRepository
(
s
.
db
)
.
(
*
accountRepository
)
}
func
TestGatewayRoutingSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayRoutingSuite
))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByPlatforms_GeminiAndAntigravity
()
{
// 创建各平台账户
geminiAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-oauth"
,
Platform
:
service
.
PlatformGemini
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
1
,
})
antigravityAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-oauth"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
2
,
Credentials
:
datatypes
.
JSONMap
{
"access_token"
:
"test-token"
,
"refresh_token"
:
"test-refresh"
,
"project_id"
:
"test-project"
,
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"anthropic-oauth"
,
Platform
:
service
.
PlatformAnthropic
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
0
,
})
// 查询 gemini + antigravity 平台
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
s
.
ctx
,
[]
string
{
service
.
PlatformGemini
,
service
.
PlatformAntigravity
,
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
2
,
"应返回 gemini 和 antigravity 两个账户"
)
// 验证返回的账户平台
platforms
:=
make
(
map
[
string
]
bool
)
for
_
,
acc
:=
range
accounts
{
platforms
[
acc
.
Platform
]
=
true
}
s
.
Require
()
.
True
(
platforms
[
service
.
PlatformGemini
],
"应包含 gemini 账户"
)
s
.
Require
()
.
True
(
platforms
[
service
.
PlatformAntigravity
],
"应包含 antigravity 账户"
)
s
.
Require
()
.
False
(
platforms
[
service
.
PlatformAnthropic
],
"不应包含 anthropic 账户"
)
// 验证账户 ID 匹配
ids
:=
make
(
map
[
int64
]
bool
)
for
_
,
acc
:=
range
accounts
{
ids
[
acc
.
ID
]
=
true
}
s
.
Require
()
.
True
(
ids
[
geminiAcc
.
ID
])
s
.
Require
()
.
True
(
ids
[
antigravityAcc
.
ID
])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding
()
{
// 创建 gemini 分组
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"gemini-group"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
})
// 创建账户
boundAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"bound-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
unboundAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"unbound-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 只绑定一个账户到分组
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
boundAcc
.
ID
,
group
.
ID
,
1
)
// 查询分组内的账户
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
s
.
ctx
,
group
.
ID
,
[]
string
{
service
.
PlatformGemini
,
service
.
PlatformAntigravity
,
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
,
"应只返回绑定到分组的账户"
)
s
.
Require
()
.
Equal
(
boundAcc
.
ID
,
accounts
[
0
]
.
ID
)
// 确认未绑定的账户不在结果中
for
_
,
acc
:=
range
accounts
{
s
.
Require
()
.
NotEqual
(
unboundAcc
.
ID
,
acc
.
ID
,
"不应包含未绑定的账户"
)
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByPlatform_Antigravity
()
{
// 创建多种平台账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-1"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
antigravity
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-1"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 只查询 antigravity 平台
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatform
(
s
.
ctx
,
service
.
PlatformAntigravity
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Equal
(
antigravity
.
ID
,
accounts
[
0
]
.
ID
)
s
.
Require
()
.
Equal
(
service
.
PlatformAntigravity
,
accounts
[
0
]
.
Platform
)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func
(
s
*
GatewayRoutingSuite
)
TestSchedulableFilter_ExcludesInactive
()
{
// 创建可调度账户
activeAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"active-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
inactiveAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"inactive-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
})
s
.
Require
()
.
NoError
(
s
.
db
.
Model
(
&
accountModel
{})
.
Where
(
"id = ?"
,
inactiveAcc
.
ID
)
.
Update
(
"schedulable"
,
false
)
.
Error
)
// 创建错误状态账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"error-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusError
,
Schedulable
:
true
,
})
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatform
(
s
.
ctx
,
service
.
PlatformAntigravity
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
,
"应只返回可调度的 active 账户"
)
s
.
Require
()
.
Equal
(
activeAcc
.
ID
,
accounts
[
0
]
.
ID
)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func
(
s
*
GatewayRoutingSuite
)
TestPlatformRoutingDecision
()
{
// 创建两种平台的账户
geminiAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-route-test"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
antigravityAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-route-test"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
tests
:=
[]
struct
{
name
string
accountID
int64
expectedService
string
}{
{
name
:
"Gemini账户路由到ForwardNative"
,
accountID
:
geminiAcc
.
ID
,
expectedService
:
"GeminiMessagesCompatService.ForwardNative"
,
},
{
name
:
"Antigravity账户路由到ForwardGemini"
,
accountID
:
antigravityAcc
.
ID
,
expectedService
:
"AntigravityGatewayService.ForwardGemini"
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
// 从数据库获取账户
account
,
err
:=
s
.
accountRepo
.
GetByID
(
s
.
ctx
,
tt
.
accountID
)
s
.
Require
()
.
NoError
(
err
)
// 模拟 Handler 层的路由决策
var
routedService
string
if
account
.
Platform
==
service
.
PlatformAntigravity
{
routedService
=
"AntigravityGatewayService.ForwardGemini"
}
else
{
routedService
=
"GeminiMessagesCompatService.ForwardNative"
}
s
.
Require
()
.
Equal
(
tt
.
expectedService
,
routedService
)
})
}
}
backend/internal/service/account_service.go
View file @
1d085d98
...
@@ -38,6 +38,8 @@ type AccountRepository interface {
...
@@ -38,6 +38,8 @@ type AccountRepository interface {
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
...
...
backend/internal/service/antigravity_gateway_service.go
0 → 100644
View file @
1d085d98
This diff is collapsed.
Click to expand it.
backend/internal/service/antigravity_model_mapping_test.go
0 → 100644
View file @
1d085d98
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestIsAntigravityModelSupported
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
model
string
expected
bool
}{
// 直接支持的模型
{
"直接支持 - claude-sonnet-4-5"
,
"claude-sonnet-4-5"
,
true
},
{
"直接支持 - claude-opus-4-5-thinking"
,
"claude-opus-4-5-thinking"
,
true
},
{
"直接支持 - claude-sonnet-4-5-thinking"
,
"claude-sonnet-4-5-thinking"
,
true
},
{
"直接支持 - gemini-2.5-flash"
,
"gemini-2.5-flash"
,
true
},
{
"直接支持 - gemini-2.5-flash-lite"
,
"gemini-2.5-flash-lite"
,
true
},
{
"直接支持 - gemini-3-pro-high"
,
"gemini-3-pro-high"
,
true
},
// 可映射的模型
{
"可映射 - claude-3-5-sonnet-20241022"
,
"claude-3-5-sonnet-20241022"
,
true
},
{
"可映射 - claude-3-5-sonnet-20240620"
,
"claude-3-5-sonnet-20240620"
,
true
},
{
"可映射 - claude-opus-4"
,
"claude-opus-4"
,
true
},
{
"可映射 - claude-haiku-4"
,
"claude-haiku-4"
,
true
},
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
// Gemini 前缀透传
{
"Gemini前缀 - gemini-1.5-pro"
,
"gemini-1.5-pro"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
// Claude 前缀兜底
{
"Claude前缀 - claude-unknown-model"
,
"claude-unknown-model"
,
true
},
{
"Claude前缀 - claude-3-opus-20240229"
,
"claude-3-opus-20240229"
,
true
},
{
"Claude前缀 - claude-future-version"
,
"claude-future-version"
,
true
},
// 不支持的模型
{
"不支持 - gpt-4"
,
"gpt-4"
,
false
},
{
"不支持 - gpt-4o"
,
"gpt-4o"
,
false
},
{
"不支持 - llama-3"
,
"llama-3"
,
false
},
{
"不支持 - mistral-7b"
,
"mistral-7b"
,
false
},
{
"不支持 - 空字符串"
,
""
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
IsAntigravityModelSupported
(
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
,
"model: %s"
,
tt
.
model
)
})
}
}
func
TestAntigravityGatewayService_GetMappedModel
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
requestedModel
string
accountMapping
map
[
string
]
string
expected
string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
{
name
:
"账户映射优先"
,
requestedModel
:
"claude-3-5-sonnet-20241022"
,
accountMapping
:
map
[
string
]
string
{
"claude-3-5-sonnet-20241022"
:
"custom-model"
},
expected
:
"custom-model"
,
},
{
name
:
"账户映射覆盖系统映射"
,
requestedModel
:
"claude-opus-4"
,
accountMapping
:
map
[
string
]
string
{
"claude-opus-4"
:
"my-opus"
},
expected
:
"my-opus"
,
},
// 2. 系统默认映射
{
name
:
"系统映射 - claude-3-5-sonnet-20241022"
,
requestedModel
:
"claude-3-5-sonnet-20241022"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-3-5-sonnet-20240620"
,
requestedModel
:
"claude-3-5-sonnet-20240620"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-opus-4"
,
requestedModel
:
"claude-opus-4"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"系统映射 - claude-opus-4-5-20251101"
,
requestedModel
:
"claude-opus-4-5-20251101"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"系统映射 - claude-haiku-4"
,
requestedModel
:
"claude-haiku-4"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-3-haiku-20240307"
,
requestedModel
:
"claude-3-haiku-20240307"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-sonnet-4-5-20250929"
,
requestedModel
:
"claude-sonnet-4-5-20250929"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5-thinking"
,
},
// 3. Gemini 透传
{
name
:
"Gemini透传 - gemini-2.5-flash"
,
requestedModel
:
"gemini-2.5-flash"
,
accountMapping
:
nil
,
expected
:
"gemini-2.5-flash"
,
},
{
name
:
"Gemini透传 - gemini-1.5-pro"
,
requestedModel
:
"gemini-1.5-pro"
,
accountMapping
:
nil
,
expected
:
"gemini-1.5-pro"
,
},
{
name
:
"Gemini透传 - gemini-future-model"
,
requestedModel
:
"gemini-future-model"
,
accountMapping
:
nil
,
expected
:
"gemini-future-model"
,
},
// 4. 直接支持的模型
{
name
:
"直接支持 - claude-sonnet-4-5"
,
requestedModel
:
"claude-sonnet-4-5"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"直接支持 - claude-opus-4-5-thinking"
,
requestedModel
:
"claude-opus-4-5-thinking"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"直接支持 - claude-sonnet-4-5-thinking"
,
requestedModel
:
"claude-sonnet-4-5-thinking"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5-thinking"
,
},
// 5. 默认值 fallback(未知 claude 模型)
{
name
:
"默认值 - claude-unknown"
,
requestedModel
:
"claude-unknown"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"默认值 - claude-3-opus-20240229"
,
requestedModel
:
"claude-3-opus-20240229"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAntigravity
,
}
if
tt
.
accountMapping
!=
nil
{
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
tt
.
accountMapping
{
mappingAny
[
k
]
=
v
}
account
.
Credentials
=
map
[
string
]
any
{
"model_mapping"
:
mappingAny
,
}
}
got
:=
svc
.
getMappedModel
(
account
,
tt
.
requestedModel
)
require
.
Equal
(
t
,
tt
.
expected
,
got
,
"model: %s"
,
tt
.
requestedModel
)
})
}
}
func
TestAntigravityGatewayService_GetMappedModel_EdgeCases
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
requestedModel
string
expected
string
}{
// 空字符串回退到默认值
{
"空字符串"
,
""
,
"claude-sonnet-4-5"
},
// 非 claude/gemini 前缀回退到默认值
{
"非claude/gemini前缀 - gpt"
,
"gpt-4"
,
"claude-sonnet-4-5"
},
{
"非claude/gemini前缀 - llama"
,
"llama-3"
,
"claude-sonnet-4-5"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAntigravity
}
got
:=
svc
.
getMappedModel
(
account
,
tt
.
requestedModel
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
func
TestAntigravityGatewayService_IsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
model
string
expected
bool
}{
// 直接支持
{
"直接支持 - claude-sonnet-4-5"
,
"claude-sonnet-4-5"
,
true
},
{
"直接支持 - gemini-3-flash"
,
"gemini-3-flash"
,
true
},
// 可映射
{
"可映射 - claude-opus-4"
,
"claude-opus-4"
,
true
},
// 前缀透传
{
"Gemini前缀"
,
"gemini-unknown"
,
true
},
{
"Claude前缀"
,
"claude-unknown"
,
true
},
// 不支持
{
"不支持 - gpt-4"
,
"gpt-4"
,
false
},
{
"不支持 - 空字符串"
,
""
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
IsModelSupported
(
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/antigravity_token_provider.go
0 → 100644
View file @
1d085d98
package
service
import
(
"context"
"errors"
"log"
"strconv"
"strings"
"time"
)
const
(
antigravityTokenRefreshSkew
=
3
*
time
.
Minute
antigravityTokenCacheSkew
=
5
*
time
.
Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
AntigravityTokenCache
=
GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type
AntigravityTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
AntigravityTokenCache
antigravityOAuthService
*
AntigravityOAuthService
}
func
NewAntigravityTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
AntigravityTokenCache
,
antigravityOAuthService
*
AntigravityOAuthService
,
)
*
AntigravityTokenProvider
{
return
&
AntigravityTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
antigravityOAuthService
:
antigravityOAuthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
AntigravityTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAntigravity
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
}
cacheKey
:=
antigravityTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
}
// 2. 如果即将过期则刷新
expiresAt
:=
parseAntigravityExpiresAt
(
account
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
{
if
p
.
antigravityOAuthService
==
nil
{
return
""
,
errors
.
New
(
"antigravity oauth service not configured"
)
}
tokenInfo
,
err
:=
p
.
antigravityOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
newCredentials
:=
p
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
log
.
Printf
(
"[AntigravityTokenProvider] Failed to update account credentials: %v"
,
updateErr
)
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
antigravityTokenCacheSkew
:
ttl
=
until
-
antigravityTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
}
func
antigravityTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
"ag:"
+
projectID
}
return
"ag:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
func
parseAntigravityExpiresAt
(
account
*
Account
)
*
time
.
Time
{
raw
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"expires_at"
))
if
raw
==
""
{
return
nil
}
if
unixSec
,
err
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
err
==
nil
&&
unixSec
>
0
{
t
:=
time
.
Unix
(
unixSec
,
0
)
return
&
t
}
if
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
raw
);
err
==
nil
{
return
&
t
}
return
nil
}
backend/internal/service/antigravity_token_refresher.go
0 → 100644
View file @
1d085d98
package
service
import
(
"context"
"strconv"
"time"
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type
AntigravityTokenRefresher
struct
{
antigravityOAuthService
*
AntigravityOAuthService
}
func
NewAntigravityTokenRefresher
(
antigravityOAuthService
*
AntigravityOAuthService
)
*
AntigravityTokenRefresher
{
return
&
AntigravityTokenRefresher
{
antigravityOAuthService
:
antigravityOAuthService
,
}
}
// CanRefresh 检查是否可以刷新此账户
func
(
r
*
AntigravityTokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformAntigravity
&&
account
.
Type
==
AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
func
(
r
*
AntigravityTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
refreshWindow
time
.
Duration
)
bool
{
if
!
r
.
CanRefresh
(
account
)
{
return
false
}
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
)
if
expiresAtStr
==
""
{
return
false
}
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
!=
nil
{
return
false
}
expiryTime
:=
time
.
Unix
(
expiresAt
,
0
)
return
time
.
Until
(
expiryTime
)
<
refreshWindow
}
// Refresh 执行 token 刷新
func
(
r
*
AntigravityTokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
tokenInfo
,
err
:=
r
.
antigravityOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
return
newCredentials
,
nil
}
backend/internal/service/gateway_multiplatform_test.go
0 → 100644
View file @
1d085d98
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForMultiplatform 多平台测试用的 mock
type
mockAccountRepoForMultiplatform
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listPlatformsFunc
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
}
func
(
m
*
mockAccountRepoForMultiplatform
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
if
m
.
listPlatformsFunc
!=
nil
{
return
m
.
listPlatformsFunc
(
ctx
,
platforms
)
}
// 过滤符合平台的账户
var
result
[]
Account
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
platformSet
[
p
]
=
true
}
for
_
,
acc
:=
range
m
.
accounts
{
if
platformSet
[
acc
.
Platform
]
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
// Stub methods to implement AccountRepository interface
func
(
m
*
mockAccountRepoForMultiplatform
)
Create
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForMultiplatform
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
// Verify interface implementation
var
_
AccountRepository
=
(
*
mockAccountRepoForMultiplatform
)(
nil
)
// mockGatewayCacheForMultiplatform 多平台测试用的 cache mock
type
mockGatewayCacheForMultiplatform
struct
{
sessionBindings
map
[
string
]
int64
}
func
(
m
*
mockGatewayCacheForMultiplatform
)
GetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
m
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
m
*
mockGatewayCacheForMultiplatform
)
SetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
m
.
sessionBindings
==
nil
{
m
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
m
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
m
*
mockGatewayCacheForMultiplatform
)
RefreshSessionTTL
(
ctx
context
.
Context
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
func
TestGatewayService_SelectAccountForModelWithExclusions_OnlyAnthropic
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应选择优先级最高的账户"
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_OnlyAntigravity
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
now
:=
time
.
Now
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
2
*
time
.
Hour
))},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择最久未用的账户(Antigravity)"
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_DiffPriority
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择优先级更高的账户(Antigravity, priority=1)"
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_ModelNotSupported
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
// Anthropic 账户配置了模型映射,只支持 other-model
// 注意:model_mapping 需要是 map[string]any 格式
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"other-model"
:
"x"
}},
},
// Antigravity 账户支持所有 claude 模型
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"Anthropic 不支持该模型,应选择 Antigravity"
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_AllExcluded
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{},
2
:
{}}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
}
func
TestGatewayService_SelectAccountForModelWithExclusions_Schedulability
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
now
:=
time
.
Now
()
tests
:=
[]
struct
{
name
string
accounts
[]
Account
expectedID
int64
}{
{
name
:
"过载账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
OverloadUntil
:
ptr
(
now
.
Add
(
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"限流账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
ptr
(
now
.
Add
(
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"非active账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
"error"
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"schedulable=false被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
false
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"过期的过载账户可调度"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
OverloadUntil
:
ptr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
1
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
tt
.
accounts
,
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
tt
.
expectedID
,
acc
.
ID
)
})
}
}
func
TestGatewayService_SelectAccountForModelWithExclusions_StickySession
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"粘性会话命中"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应返回粘性会话绑定的账户"
)
})
t
.
Run
(
"粘性会话账户被排除-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户被排除,应选择其他账户"
)
})
t
.
Run
(
"粘性会话账户不可调度-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForMultiplatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
"error"
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForMultiplatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatforms
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户不可调度,应选择其他账户"
)
})
}
func
TestGatewayService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
},
{
name
:
"Anthropic平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Anthropic平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-opus-4"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
false
,
},
{
name
:
"Anthropic平台-有映射配置-支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet-20241022"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/gateway_service.go
View file @
1d085d98
...
@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
...
@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func
(
s
*
GatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 使用多平台账户选择,包含 anthropic 和 antigravity 平台
platforms
:=
[]
string
{
PlatformAnthropic
,
PlatformAntigravity
}
return
s
.
selectAccountForModelWithPlatforms
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platforms
)
}
// selectAccountForModelWithPlatforms 选择多平台账户
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatforms
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platforms
[]
string
)
(
*
Account
,
error
)
{
// 1. 查询粘性会话
// 1. 查询粘性会话
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
sessionHash
)
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
sessionHash
)
...
@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 同时检查模型支持
// 同时检查模型支持
(根据平台类型分别处理)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
account
.
I
sModelSupported
(
requestedModel
))
{
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
s
.
i
sModelSupported
ByAccount
(
account
,
requestedModel
))
{
// 续期粘性会话
// 续期粘性会话
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
}
}
}
// 2. 获取可调度账号列表(排除限流和过载的账号,
仅限 Anthropic
平台)
// 2. 获取可调度账号列表(排除限流和过载的账号,
支持多
平台)
var
accounts
[]
Account
var
accounts
[]
Account
var
err
error
var
err
error
if
groupID
!=
nil
{
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
P
latform
Anthropic
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
s
(
ctx
,
*
groupID
,
p
latform
s
)
}
else
{
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
P
latform
Anthropic
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
s
(
ctx
,
p
latform
s
)
}
}
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
...
@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
continue
}
}
// 检查模型支持
// 检查模型支持
(根据平台类型分别处理)
if
requestedModel
!=
""
&&
!
acc
.
I
sModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
s
.
i
sModelSupported
ByAccount
(
acc
,
requestedModel
)
{
continue
continue
}
}
if
selected
==
nil
{
if
selected
==
nil
{
...
@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return
selected
,
nil
return
selected
,
nil
}
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func
(
s
*
GatewayService
)
isModelSupportedByAccount
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
.
Platform
==
PlatformAntigravity
{
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
}
// 其他平台使用账户的模型支持检查
return
account
.
IsModelSupported
(
requestedModel
)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
func
IsAntigravityModelSupported
(
requestedModel
string
)
bool
{
// 直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
true
}
// 可映射的模型
if
_
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
true
}
// Gemini 前缀透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
if
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
{
return
true
}
return
false
}
// GetAccessToken 获取账号凭证
// GetAccessToken 获取账号凭证
func
(
s
*
GatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
func
(
s
*
GatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
switch
account
.
Type
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
1d085d98
...
@@ -33,11 +33,12 @@ const (
...
@@ -33,11 +33,12 @@ const (
)
)
type
GeminiMessagesCompatService
struct
{
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
cache
GatewayCache
cache
GatewayCache
tokenProvider
*
GeminiTokenProvider
tokenProvider
*
GeminiTokenProvider
rateLimitService
*
RateLimitService
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
antigravityGatewayService
*
AntigravityGatewayService
}
}
func
NewGeminiMessagesCompatService
(
func
NewGeminiMessagesCompatService
(
...
@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService(
...
@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService(
tokenProvider
*
GeminiTokenProvider
,
tokenProvider
*
GeminiTokenProvider
,
rateLimitService
*
RateLimitService
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
antigravityGatewayService
*
AntigravityGatewayService
,
)
*
GeminiMessagesCompatService
{
)
*
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
cache
:
cache
,
cache
:
cache
,
tokenProvider
:
tokenProvider
,
tokenProvider
:
tokenProvider
,
rateLimitService
:
rateLimitService
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
antigravityGatewayService
:
antigravityGatewayService
,
}
}
}
}
...
@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
...
@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
cacheKey
:=
"gemini:"
+
sessionHash
cacheKey
:=
"gemini:"
+
sessionHash
platforms
:=
[]
string
{
PlatformGemini
,
PlatformAntigravity
}
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
cacheKey
)
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
cacheKey
)
if
err
==
nil
&&
accountID
>
0
{
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
Platform
==
PlatformGemini
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// 支持 gemini 和 antigravity 平台的粘性会话
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
account
.
Platform
==
PlatformGemini
||
account
.
Platform
==
PlatformAntigravity
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
cacheKey
,
geminiStickySessionTTL
)
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
cacheKey
,
geminiStickySessionTTL
)
return
account
,
nil
return
account
,
nil
}
}
...
@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
}
}
// 同时查询 gemini 和 antigravity 平台的可调度账户
var
accounts
[]
Account
var
accounts
[]
Account
var
err
error
var
err
error
if
groupID
!=
nil
{
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
P
latform
Gemini
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
s
(
ctx
,
*
groupID
,
p
latform
s
)
}
else
{
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
P
latform
Gemini
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
s
(
ctx
,
p
latform
s
)
}
}
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
...
@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
continue
}
}
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
// 根据平台类型分别检查模型支持
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
continue
}
}
if
selected
==
nil
{
if
selected
==
nil
{
...
@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if
selected
==
nil
{
if
selected
==
nil
{
if
requestedModel
!=
""
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available Gemini accounts supporting model: %s"
,
requestedModel
)
return
nil
,
fmt
.
Errorf
(
"no available Gemini
/Antigravity
accounts supporting model: %s"
,
requestedModel
)
}
}
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
return
nil
,
errors
.
New
(
"no available Gemini
/Antigravity
accounts"
)
}
}
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
...
@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return
selected
,
nil
return
selected
,
nil
}
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func
(
s
*
GeminiMessagesCompatService
)
isModelSupportedByAccount
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
.
Platform
==
PlatformAntigravity
{
return
IsAntigravityModelSupported
(
requestedModel
)
}
return
account
.
IsModelSupported
(
requestedModel
)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func
(
s
*
GeminiMessagesCompatService
)
GetAntigravityGatewayService
()
*
AntigravityGatewayService
{
return
s
.
antigravityGatewayService
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformAntigravity
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformAntigravity
)
}
if
err
!=
nil
{
return
false
,
err
}
return
len
(
accounts
)
>
0
,
nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
//
...
...
backend/internal/service/gemini_multiplatform_test.go
0 → 100644
View file @
1d085d98
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForGemini Gemini 测试用的 mock
type
mockAccountRepoForGemini
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
}
func
(
m
*
mockAccountRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
platformSet
[
p
]
=
true
}
var
result
[]
Account
for
_
,
acc
:=
range
m
.
accounts
{
if
platformSet
[
acc
.
Platform
]
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
// Stub methods to implement AccountRepository interface
func
(
m
*
mockAccountRepoForGemini
)
Create
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
var
result
[]
Account
for
_
,
acc
:=
range
m
.
accounts
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
// 测试时不区分 groupID,直接按 platform 过滤
return
m
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
func
(
m
*
mockAccountRepoForGemini
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
// Verify interface implementation
var
_
AccountRepository
=
(
*
mockAccountRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type
mockGatewayCacheForGemini
struct
{
sessionBindings
map
[
string
]
int64
}
func
(
m
*
mockGatewayCacheForGemini
)
GetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
m
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
m
*
mockGatewayCacheForGemini
)
SetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
m
.
sessionBindings
==
nil
{
m
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
m
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
m
*
mockGatewayCacheForGemini
)
RefreshSessionTTL
(
ctx
context
.
Context
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyGemini
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应选择优先级最高的账户"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyAntigravity
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludesAnthropic
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
3
,
Platform
:
PlatformAntigravity
,
Priority
:
3
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
// Anthropic 不在 [gemini, antigravity] 平台列表中,应被过滤
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"Anthropic 平台应被排除,选择 Gemini"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
now
:=
time
.
Now
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
2
*
time
.
Hour
))},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择最久未用的账户(Antigravity)"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeApiKey
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且都未使用时,应优先选择 OAuth 账户"
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
acc
.
Type
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred_MixedPlatforms
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeApiKey
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"跨平台时,同样优先选择 OAuth 账户"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForGemini
{}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available Gemini/Antigravity accounts"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"粘性会话命中-使用gemini前缀缓存键"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
// 注意:缓存键使用 "gemini:" 前缀
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应返回粘性会话绑定的账户"
)
})
t
.
Run
(
"粘性会话不命中无前缀缓存键"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
// 缓存键没有 "gemini:" 前缀,不应命中
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
// 粘性会话未命中,按优先级选择
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话未命中,应按优先级选择 Antigravity"
)
})
t
.
Run
(
"粘性会话Anthropic账户-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
// 粘性会话绑定的是 Anthropic 账户,不在 Gemini 平台列表中,应降级选择
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户是 Anthropic,应降级选择 Gemini 平台账户"
)
})
}
func
TestGeminiMessagesCompatService_HasAntigravityAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"有antigravity账户时返回true"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
}
has
,
err
:=
svc
.
HasAntigravityAccounts
(
ctx
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
has
)
})
t
.
Run
(
"无antigravity账户时返回false"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
}
has
,
err
:=
svc
.
HasAntigravityAccounts
(
ctx
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
has
)
})
t
.
Run
(
"antigravity账户不可调度时返回false"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Status
:
StatusActive
,
Schedulable
:
false
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
}
has
,
err
:=
svc
.
HasAntigravityAccounts
(
ctx
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
has
)
})
t
.
Run
(
"带groupID查询"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
}
groupID
:=
int64
(
1
)
has
,
err
:=
svc
.
HasAntigravityAccounts
(
ctx
,
&
groupID
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
has
)
})
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
// 该测试文档化了 Handler 层应该如何根据 account.Platform 进行分流
func
TestGeminiPlatformRouting_DocumentRouteDecision
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
platform
string
expectedService
string
// "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
}{
{
name
:
"Gemini平台走ForwardNative"
,
platform
:
PlatformGemini
,
expectedService
:
"gemini"
,
},
{
name
:
"Antigravity平台走ForwardGemini"
,
platform
:
PlatformAntigravity
,
expectedService
:
"antigravity"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
tt
.
platform
}
// 模拟 Handler 层的路由逻辑
var
serviceName
string
if
account
.
Platform
==
PlatformAntigravity
{
serviceName
=
"antigravity"
}
else
{
serviceName
=
"gemini"
}
require
.
Equal
(
t
,
tt
.
expectedService
,
serviceName
,
"平台 %s 应该路由到 %s 服务"
,
tt
.
platform
,
tt
.
expectedService
)
})
}
}
func
TestGeminiMessagesCompatService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{}
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
},
{
name
:
"Gemini平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Gemini平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-1.5-pro"
:
"x"
}},
},
model
:
"gemini-2.5-flash"
,
expected
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/token_refresh_service.go
View file @
1d085d98
...
@@ -27,6 +27,7 @@ func NewTokenRefreshService(
...
@@ -27,6 +27,7 @@ func NewTokenRefreshService(
oauthService
*
OAuthService
,
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
)
*
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
...
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
...
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
NewClaudeTokenRefresher
(
oauthService
),
NewClaudeTokenRefresher
(
oauthService
),
NewOpenAITokenRefresher
(
openaiOAuthService
),
NewOpenAITokenRefresher
(
openaiOAuthService
),
NewGeminiTokenRefresher
(
geminiOAuthService
),
NewGeminiTokenRefresher
(
geminiOAuthService
),
NewAntigravityTokenRefresher
(
antigravityOAuthService
),
}
}
return
s
return
s
...
...
backend/internal/service/wire.go
View file @
1d085d98
...
@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
...
@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
oauthService
*
OAuthService
,
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
)
*
TokenRefreshService
{
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
cfg
)
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cfg
)
svc
.
Start
()
svc
.
Start
()
return
svc
return
svc
}
}
...
@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet(
...
@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet(
NewAntigravityOAuthService
,
NewAntigravityOAuthService
,
NewGeminiTokenProvider
,
NewGeminiTokenProvider
,
NewGeminiMessagesCompatService
,
NewGeminiMessagesCompatService
,
NewAntigravityTokenProvider
,
NewAntigravityGatewayService
,
NewRateLimitService
,
NewRateLimitService
,
NewAccountUsageService
,
NewAccountUsageService
,
NewAccountTestService
,
NewAccountTestService
,
...
...
frontend/src/components/common/GroupSelector.vue
View file @
1d085d98
...
@@ -62,6 +62,10 @@ const filteredGroups = computed(() => {
...
@@ -62,6 +62,10 @@ const filteredGroups = computed(() => {
if
(
!
props
.
platform
)
{
if
(
!
props
.
platform
)
{
return
props
.
groups
return
props
.
groups
}
}
// antigravity 账户可选择 anthropic 和 gemini 平台的分组
if
(
props
.
platform
===
'
antigravity
'
)
{
return
props
.
groups
.
filter
((
g
)
=>
g
.
platform
===
'
anthropic
'
||
g
.
platform
===
'
gemini
'
)
}
return
props
.
groups
.
filter
((
g
)
=>
g
.
platform
===
props
.
platform
)
return
props
.
groups
.
filter
((
g
)
=>
g
.
platform
===
props
.
platform
)
})
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment