Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
f6ed3d14
Commit
f6ed3d14
authored
Jan 20, 2026
by
yangjianbo
Browse files
Merge branch 'test' into dev
parents
bdc426a7
84686753
Changes
19
Expand all
Show whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
f6ed3d14
...
...
@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService
:=
service
.
ProvideDashboardAggregationService
(
dashboardAggregationRepository
,
timingWheelService
,
configConfig
)
dashboardHandler
:=
admin
.
NewDashboardHandler
(
dashboardService
,
dashboardAggregationService
)
accountRepository
:=
repository
.
NewAccountRepository
(
client
,
db
)
schedulerCache
:=
repository
.
NewSchedulerCache
(
redisClient
)
accountRepository
:=
repository
.
NewAccountRepository
(
client
,
db
,
schedulerCache
)
proxyRepository
:=
repository
.
NewProxyRepository
(
client
,
db
)
proxyExitInfoProber
:=
repository
.
NewProxyExitInfoProber
(
configConfig
)
proxyLatencyCache
:=
repository
.
NewProxyLatencyCache
(
redisClient
)
...
...
@@ -128,7 +129,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler
:=
admin
.
NewRedeemHandler
(
adminService
)
promoHandler
:=
admin
.
NewPromoHandler
(
promoService
)
opsRepository
:=
repository
.
NewOpsRepository
(
db
)
schedulerCache
:=
repository
.
NewSchedulerCache
(
redisClient
)
schedulerOutboxRepository
:=
repository
.
NewSchedulerOutboxRepository
(
db
)
schedulerSnapshotService
:=
service
.
ProvideSchedulerSnapshotService
(
schedulerCache
,
schedulerOutboxRepository
,
accountRepository
,
groupRepository
,
configConfig
)
pricingRemoteClient
:=
repository
.
ProvidePricingRemoteClient
(
configConfig
)
...
...
backend/internal/repository/account_repo.go
View file @
f6ed3d14
...
...
@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type
accountRepository
struct
{
client
*
dbent
.
Client
// Ent ORM 客户端
sql
sqlExecutor
// 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache
service
.
SchedulerCache
}
type
tempUnschedSnapshot
struct
{
...
...
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func
NewAccountRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
)
service
.
AccountRepository
{
return
newAccountRepositoryWithSQL
(
client
,
sqlDB
)
func
NewAccountRepository
(
client
*
dbent
.
Client
,
sqlDB
*
sql
.
DB
,
schedulerCache
service
.
SchedulerCache
)
service
.
AccountRepository
{
return
newAccountRepositoryWithSQL
(
client
,
sqlDB
,
schedulerCache
)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func
newAccountRepositoryWithSQL
(
client
*
dbent
.
Client
,
sqlq
sqlExecutor
)
*
accountRepository
{
return
&
accountRepository
{
client
:
client
,
sql
:
sqlq
}
func
newAccountRepositoryWithSQL
(
client
*
dbent
.
Client
,
sqlq
sqlExecutor
,
schedulerCache
service
.
SchedulerCache
)
*
accountRepository
{
return
&
accountRepository
{
client
:
client
,
sql
:
sqlq
,
schedulerCache
:
schedulerCache
}
}
func
(
r
*
accountRepository
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
...
...
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
account
.
ID
,
nil
,
buildSchedulerGroupPayload
(
account
.
GroupIDs
));
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue account update failed: account=%d err=%v"
,
account
.
ID
,
err
)
}
if
account
.
Status
==
service
.
StatusError
||
account
.
Status
==
service
.
StatusDisabled
||
!
account
.
Schedulable
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
account
.
ID
)
}
return
nil
}
...
...
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue set error failed: account=%d err=%v"
,
id
,
err
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
}
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func
(
r
*
accountRepository
)
syncSchedulerAccountSnapshot
(
ctx
context
.
Context
,
accountID
int64
)
{
if
r
==
nil
||
r
.
schedulerCache
==
nil
||
accountID
<=
0
{
return
}
account
,
err
:=
r
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
log
.
Printf
(
"[Scheduler] sync account snapshot read failed: id=%d err=%v"
,
accountID
,
err
)
return
}
if
err
:=
r
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[Scheduler] sync account snapshot write failed: id=%d err=%v"
,
accountID
,
err
)
}
}
func
(
r
*
accountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
_
,
err
:=
r
.
client
.
AccountGroup
.
Create
()
.
SetAccountID
(
accountID
)
.
...
...
@@ -864,6 +896,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v"
,
id
,
err
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
}
...
...
@@ -974,6 +1007,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v"
,
id
,
err
)
}
if
!
schedulable
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
}
return
nil
}
...
...
@@ -1128,6 +1164,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountBulkChanged
,
nil
,
nil
,
payload
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue bulk update failed: err=%v"
,
err
)
}
shouldSync
:=
false
if
updates
.
Status
!=
nil
&&
(
*
updates
.
Status
==
service
.
StatusError
||
*
updates
.
Status
==
service
.
StatusDisabled
)
{
shouldSync
=
true
}
if
updates
.
Schedulable
!=
nil
&&
!*
updates
.
Schedulable
{
shouldSync
=
true
}
if
shouldSync
{
for
_
,
id
:=
range
ids
{
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
}
}
}
return
rows
,
nil
}
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
f6ed3d14
...
...
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo
*
accountRepository
}
type
schedulerCacheRecorder
struct
{
setAccounts
[]
*
service
.
Account
}
func
(
s
*
schedulerCacheRecorder
)
GetSnapshot
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
)
([]
*
service
.
Account
,
bool
,
error
)
{
return
nil
,
false
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
SetSnapshot
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
,
accounts
[]
service
.
Account
)
error
{
return
nil
}
func
(
s
*
schedulerCacheRecorder
)
GetAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
SetAccount
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
s
.
setAccounts
=
append
(
s
.
setAccounts
,
account
)
return
nil
}
func
(
s
*
schedulerCacheRecorder
)
DeleteAccount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
(
s
*
schedulerCacheRecorder
)
UpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
s
*
schedulerCacheRecorder
)
TryLockBucket
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
ListBuckets
(
ctx
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
return
nil
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
GetOutboxWatermark
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
SetOutboxWatermark
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
s
*
AccountRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
s
.
client
=
tx
.
Client
()
s
.
repo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
)
s
.
repo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
,
nil
)
}
func
TestAccountRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s
.
Require
()
.
Equal
(
"updated"
,
got
.
Name
)
}
func
(
s
*
AccountRepoSuite
)
TestUpdate_SyncSchedulerSnapshotOnDisabled
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"sync-update"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
})
cacheRecorder
:=
&
schedulerCacheRecorder
{}
s
.
repo
.
schedulerCache
=
cacheRecorder
account
.
Status
=
service
.
StatusDisabled
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
account
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
s
.
Require
()
.
Len
(
cacheRecorder
.
setAccounts
,
1
)
s
.
Require
()
.
Equal
(
account
.
ID
,
cacheRecorder
.
setAccounts
[
0
]
.
ID
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
cacheRecorder
.
setAccounts
[
0
]
.
Status
)
}
func
(
s
*
AccountRepoSuite
)
TestDelete
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"to-delete"
})
...
...
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx
:=
testEntTx
(
s
.
T
())
client
:=
tx
.
Client
()
repo
:=
newAccountRepositoryWithSQL
(
client
,
tx
)
repo
:=
newAccountRepositoryWithSQL
(
client
,
tx
,
nil
)
ctx
:=
context
.
Background
()
tt
.
setup
(
client
)
...
...
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func
(
s
*
AccountRepoSuite
)
TestSetSchedulable
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"acc-sched"
,
Schedulable
:
true
})
cacheRecorder
:=
&
schedulerCacheRecorder
{}
s
.
repo
.
schedulerCache
=
cacheRecorder
s
.
Require
()
.
NoError
(
s
.
repo
.
SetSchedulable
(
s
.
ctx
,
account
.
ID
,
false
))
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
False
(
got
.
Schedulable
)
s
.
Require
()
.
Len
(
cacheRecorder
.
setAccounts
,
1
)
s
.
Require
()
.
Equal
(
account
.
ID
,
cacheRecorder
.
setAccounts
[
0
]
.
ID
)
}
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_SyncSchedulerSnapshotOnDisabled
()
{
account1
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"bulk-1"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
})
account2
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"bulk-2"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
})
cacheRecorder
:=
&
schedulerCacheRecorder
{}
s
.
repo
.
schedulerCache
=
cacheRecorder
disabled
:=
service
.
StatusDisabled
rows
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
account1
.
ID
,
account2
.
ID
},
service
.
AccountBulkUpdate
{
Status
:
&
disabled
,
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
int64
(
2
),
rows
)
s
.
Require
()
.
Len
(
cacheRecorder
.
setAccounts
,
2
)
ids
:=
map
[
int64
]
struct
{}{}
for
_
,
acc
:=
range
cacheRecorder
.
setAccounts
{
ids
[
acc
.
ID
]
=
struct
{}{}
}
s
.
Require
()
.
Contains
(
ids
,
account1
.
ID
)
s
.
Require
()
.
Contains
(
ids
,
account2
.
ID
)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
...
...
backend/internal/repository/gateway_cache.go
View file @
f6ed3d14
...
...
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Expire
(
ctx
,
key
,
ttl
)
.
Err
()
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func
(
c
*
gatewayCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
backend/internal/repository/gateway_cache_integration_test.go
View file @
f6ed3d14
...
...
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require
.
NoError
(
s
.
T
(),
err
,
"RefreshSessionTTL on missing key should not error"
)
}
func
(
s
*
GatewayCacheSuite
)
TestDeleteSessionAccountID
()
{
sessionID
:=
"openai:s4"
accountID
:=
int64
(
102
)
groupID
:=
int64
(
1
)
sessionTTL
:=
1
*
time
.
Minute
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetSessionAccountID
(
s
.
ctx
,
groupID
,
sessionID
,
accountID
,
sessionTTL
),
"SetSessionAccountID"
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteSessionAccountID
(
s
.
ctx
,
groupID
,
sessionID
),
"DeleteSessionAccountID"
)
_
,
err
:=
s
.
cache
.
GetSessionAccountID
(
s
.
ctx
,
groupID
,
sessionID
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil after delete"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGetSessionAccountID_CorruptedValue
()
{
sessionID
:=
"corrupted"
groupID
:=
int64
(
1
)
...
...
backend/internal/repository/gateway_routing_integration_test.go
View file @
f6ed3d14
...
...
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
s
.
client
=
tx
.
Client
()
s
.
accountRepo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
)
s
.
accountRepo
=
newAccountRepositoryWithSQL
(
s
.
client
,
tx
,
nil
)
}
func
TestGatewayRoutingSuite
(
t
*
testing
.
T
)
{
...
...
backend/internal/repository/openai_oauth_service.go
View file @
f6ed3d14
...
...
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...
...
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
}
func
(
s
*
openaiOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
:=
createOpenAIReqClient
(
proxyURL
)
client
:=
createOpenAIReqClient
(
s
.
tokenURL
,
proxyURL
)
if
redirectURI
==
""
{
redirectURI
=
openai
.
DefaultRedirectURI
...
...
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func
(
s
*
openaiOAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
:=
createOpenAIReqClient
(
proxyURL
)
client
:=
createOpenAIReqClient
(
s
.
tokenURL
,
proxyURL
)
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"refresh_token"
)
...
...
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return
&
tokenResp
,
nil
}
func
createOpenAIReqClient
(
proxyURL
string
)
*
req
.
Client
{
func
createOpenAIReqClient
(
tokenURL
,
proxyURL
string
)
*
req
.
Client
{
forceHTTP2
:=
false
if
parsedURL
,
err
:=
url
.
Parse
(
tokenURL
);
err
==
nil
{
forceHTTP2
=
strings
.
EqualFold
(
parsedURL
.
Scheme
,
"https"
)
}
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
60
*
time
.
Second
,
Timeout
:
120
*
time
.
Second
,
ForceHTTP2
:
forceHTTP2
,
})
}
backend/internal/repository/openai_oauth_service_test.go
View file @
f6ed3d14
...
...
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require
.
ErrorContains
(
s
.
T
(),
err
,
"status 401"
)
}
func
TestNewOpenAIOAuthClient_DefaultTokenURL
(
t
*
testing
.
T
)
{
client
:=
NewOpenAIOAuthClient
()
svc
,
ok
:=
client
.
(
*
openaiOAuthService
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
TokenURL
,
svc
.
tokenURL
)
}
func
TestOpenAIOAuthServiceSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
OpenAIOAuthServiceSuite
))
}
backend/internal/repository/req_client_pool.go
View file @
f6ed3d14
...
...
@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL
string
// 代理 URL(支持 http/https/socks5)
Timeout
time
.
Duration
// 请求超时时间
Impersonate
bool
// 是否模拟 Chrome 浏览器指纹
ForceHTTP2
bool
// 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
...
...
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client
:=
req
.
C
()
.
SetTimeout
(
opts
.
Timeout
)
if
opts
.
ForceHTTP2
{
client
=
client
.
EnableForceHTTP2
()
}
if
opts
.
Impersonate
{
client
=
client
.
ImpersonateChrome
()
}
...
...
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func
buildReqClientKey
(
opts
reqClientOptions
)
string
{
return
fmt
.
Sprintf
(
"%s|%s|%t"
,
return
fmt
.
Sprintf
(
"%s|%s|%t
|%t
"
,
strings
.
TrimSpace
(
opts
.
ProxyURL
),
opts
.
Timeout
.
String
(),
opts
.
Impersonate
,
opts
.
ForceHTTP2
,
)
}
backend/internal/repository/req_client_pool_test.go
0 → 100644
View file @
f6ed3d14
package
repository
import
(
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func
forceHTTPVersion
(
t
*
testing
.
T
,
client
*
req
.
Client
)
string
{
t
.
Helper
()
transport
:=
client
.
GetTransport
()
field
:=
reflect
.
ValueOf
(
transport
)
.
Elem
()
.
FieldByName
(
"forceHttpVersion"
)
require
.
True
(
t
,
field
.
IsValid
(),
"forceHttpVersion field not found"
)
require
.
True
(
t
,
field
.
CanAddr
(),
"forceHttpVersion field not addressable"
)
return
reflect
.
NewAt
(
field
.
Type
(),
unsafe
.
Pointer
(
field
.
UnsafeAddr
()))
.
Elem
()
.
String
()
}
func
TestGetSharedReqClient_ForceHTTP2SeparatesCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
base
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
time
.
Second
,
}
clientDefault
:=
getSharedReqClient
(
base
)
force
:=
base
force
.
ForceHTTP2
=
true
clientForce
:=
getSharedReqClient
(
force
)
require
.
NotSame
(
t
,
clientDefault
,
clientForce
)
require
.
NotEqual
(
t
,
buildReqClientKey
(
base
),
buildReqClientKey
(
force
))
}
func
TestGetSharedReqClient_ReuseCachedClient
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
"http://proxy.local:8080"
,
Timeout
:
2
*
time
.
Second
,
}
first
:=
getSharedReqClient
(
opts
)
second
:=
getSharedReqClient
(
opts
)
require
.
Same
(
t
,
first
,
second
)
}
func
TestGetSharedReqClient_IgnoresNonClientCache
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
3
*
time
.
Second
,
}
key
:=
buildReqClientKey
(
opts
)
sharedReqClients
.
Store
(
key
,
"invalid"
)
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
loaded
,
ok
:=
sharedReqClients
.
Load
(
key
)
require
.
True
(
t
,
ok
)
require
.
IsType
(
t
,
"invalid"
,
loaded
)
}
func
TestGetSharedReqClient_ImpersonateAndProxy
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
opts
:=
reqClientOptions
{
ProxyURL
:
" http://proxy.local:8080 "
,
Timeout
:
4
*
time
.
Second
,
Impersonate
:
true
,
}
client
:=
getSharedReqClient
(
opts
)
require
.
NotNil
(
t
,
client
)
require
.
Equal
(
t
,
"http://proxy.local:8080|4s|true|false"
,
buildReqClientKey
(
opts
))
}
func
TestCreateOpenAIReqClient_ForceHTTP2Enabled
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createOpenAIReqClient
(
"https://auth.openai.com/oauth/token"
,
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
"2"
,
forceHTTPVersion
(
t
,
client
))
}
func
TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createOpenAIReqClient
(
"http://localhost/oauth/token"
,
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
""
,
forceHTTPVersion
(
t
,
client
))
}
func
TestCreateOpenAIReqClient_Timeout120Seconds
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createOpenAIReqClient
(
"https://auth.openai.com/oauth/token"
,
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
120
*
time
.
Second
,
client
.
GetClient
()
.
Timeout
)
}
func
TestCreateGeminiReqClient_ForceHTTP2Disabled
(
t
*
testing
.
T
)
{
sharedReqClients
=
sync
.
Map
{}
client
:=
createGeminiReqClient
(
"http://proxy.local:8080"
)
require
.
Equal
(
t
,
""
,
forceHTTPVersion
(
t
,
client
))
}
backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
View file @
f6ed3d14
...
...
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_
,
_
=
integrationDB
.
ExecContext
(
ctx
,
"TRUNCATE scheduler_outbox"
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
)
accountRepo
:=
newAccountRepositoryWithSQL
(
client
,
integrationDB
,
nil
)
outboxRepo
:=
NewSchedulerOutboxRepository
(
integrationDB
)
cache
:=
NewSchedulerCache
(
rdb
)
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
f6ed3d14
This diff is collapsed.
Click to expand it.
backend/internal/service/gateway_service.go
View file @
f6ed3d14
...
...
@@ -97,11 +97,24 @@ var allowedHeaders = map[string]bool{
"content-type"
:
true
,
}
// GatewayCache defines cache operations for gateway service
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type
GatewayCache
interface
{
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...
...
@@ -112,6 +125,28 @@ func derefGroupID(groupID *int64) int64 {
return
*
groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func
shouldClearStickySession
(
account
*
Account
)
bool
{
if
account
==
nil
{
return
false
}
if
account
.
Status
==
StatusError
||
account
.
Status
==
StatusDisabled
||
!
account
.
Schedulable
{
return
true
}
if
account
.
TempUnschedulableUntil
!=
nil
&&
time
.
Now
()
.
Before
(
*
account
.
TempUnschedulableUntil
)
{
return
true
}
return
false
}
type
AccountWaitPlan
struct
{
AccountID
int64
MaxConcurrency
int
...
...
@@ -620,6 +655,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
else
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
}
}
...
...
@@ -720,7 +757,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
ok
:=
accountByID
[
accountID
]
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
if
ok
{
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
&&
...
...
@@ -728,6 +772,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
// Session count limit check
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
...
...
@@ -755,6 +800,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates
:=
make
([]
*
Account
,
0
,
len
(
accounts
))
...
...
@@ -1278,7 +1324,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -1290,6 +1341,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
...
...
@@ -1375,7 +1427,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -1384,6 +1441,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2. 获取可调度账号列表(单平台)
if
!
accountsLoaded
{
...
...
@@ -1479,7 +1537,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
...
@@ -1493,6 +1556,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2) Select an account from the routed candidates.
var
err
error
...
...
@@ -1578,7 +1642,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
...
@@ -1589,6 +1658,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2. 获取可调度账号列表
if
!
accountsLoaded
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
f6ed3d14
...
...
@@ -82,145 +82,276 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform
,
useMixedScheduling
,
hasForcePlatform
,
err
:=
s
.
resolvePlatformAndSchedulingMode
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
cacheKey
:=
"gemini:"
+
sessionHash
// 2. 尝试粘性会话命中
// Try sticky session hit
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
cacheKey
,
requestedModel
,
excludedIDs
,
platform
,
useMixedScheduling
);
account
!=
nil
{
return
account
,
nil
}
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
groupID
!=
nil
&&
hasForcePlatform
{
accounts
,
err
=
s
.
listSchedulableAccountsOnce
(
ctx
,
nil
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected
:=
s
.
selectBestGeminiAccount
(
ctx
,
accounts
,
requestedModel
,
excludedIDs
,
platform
,
useMixedScheduling
)
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available Gemini accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
}
return
selected
,
nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func
(
s
*
GeminiMessagesCompatService
)
resolvePlatformAndSchedulingMode
(
ctx
context
.
Context
,
groupID
*
int64
)
(
platform
string
,
useMixedScheduling
bool
,
hasForcePlatform
bool
,
err
error
)
{
// 优先检查 context 中的强制平台(/antigravity 路由)
var
platform
string
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
return
forcePlatform
,
false
,
true
,
nil
}
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 gemini 平台
platform
=
PlatformGemini
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return
group
.
Platform
,
group
.
Platform
==
PlatformGemini
,
false
,
nil
}
//
gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling
:=
platform
==
PlatformGemini
&&
!
hasForcePlatform
//
无分组时只使用原生 gemini 平台
return
PlatformGemini
,
true
,
false
,
nil
}
cacheKey
:=
"gemini:"
+
sessionHash
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func
(
s
*
GeminiMessagesCompatService
)
tryStickySessionHit
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
,
cacheKey
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
,
useMixedScheduling
bool
,
)
*
Account
{
if
sessionHash
==
""
{
return
nil
}
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
valid
:=
false
if
account
.
Platform
==
platform
{
valid
=
true
}
else
if
useMixedScheduling
&&
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
()
{
valid
=
true
if
err
!=
nil
||
accountID
<=
0
{
return
nil
}
if
valid
{
usable
:=
true
if
s
.
rateLimitService
!=
nil
&&
requestedModel
!=
""
{
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
account
,
requestedModel
)
if
_
,
excluded
:=
excludedIDs
[
accountID
];
excluded
{
return
nil
}
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
account
.
ID
,
err
)
return
nil
}
if
!
ok
{
usable
=
false
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if
shouldClearStickySession
(
account
)
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
return
nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if
!
s
.
isAccountUsableForRequest
(
ctx
,
account
,
requestedModel
,
platform
,
useMixedScheduling
)
{
return
nil
}
if
usable
{
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
geminiStickySessionTTL
)
return
account
,
nil
return
account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func
(
s
*
GeminiMessagesCompatService
)
isAccountUsableForRequest
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
,
platform
string
,
useMixedScheduling
bool
,
)
bool
{
// 检查模型调度能力
// Check model scheduling capability
if
!
account
.
IsSchedulableForModel
(
requestedModel
)
{
return
false
}
// 检查模型支持
// Check model support
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
account
,
requestedModel
)
{
return
false
}
// 检查平台匹配
// Check platform matching
if
!
s
.
isAccountValidForPlatform
(
account
,
platform
,
useMixedScheduling
)
{
return
false
}
// 速率限制预检
// Rate limit precheck
if
!
s
.
passesRateLimitPreCheck
(
ctx
,
account
,
requestedModel
)
{
return
false
}
return
true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func
(
s
*
GeminiMessagesCompatService
)
isAccountValidForPlatform
(
account
*
Account
,
platform
string
,
useMixedScheduling
bool
)
bool
{
if
account
.
Platform
==
platform
{
return
true
}
if
useMixedScheduling
&&
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
()
{
return
true
}
return
false
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func
(
s
*
GeminiMessagesCompatService
)
passesRateLimitPreCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
bool
{
if
s
.
rateLimitService
==
nil
||
requestedModel
==
""
{
return
true
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
groupID
!=
nil
&&
hasForcePlatform
{
accounts
,
err
=
s
.
listSchedulableAccountsOnce
(
ctx
,
nil
,
platform
,
hasForcePlatform
)
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
account
,
requestedModel
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
account
.
ID
,
err
)
}
return
ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func
(
s
*
GeminiMessagesCompatService
)
selectBestGeminiAccount
(
ctx
context
.
Context
,
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
,
useMixedScheduling
bool
,
)
*
Account
{
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
// 跳过被排除的账号
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
// 非混合调度模式(antigravity 分组):不需要过滤
if
useMixedScheduling
&&
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
s
.
rateLimitService
!=
nil
&&
requestedModel
!=
""
{
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
acc
,
requestedModel
)
if
err
!=
nil
{
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
acc
.
ID
,
err
)
}
if
!
ok
{
// 检查账号是否可用于当前请求
if
!
s
.
isAccountUsableForRequest
(
ctx
,
acc
,
requestedModel
,
platform
,
useMixedScheduling
)
{
continue
}
}
// 选择最佳账号
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if
acc
.
Type
==
AccountTypeOAuth
&&
selected
.
Type
!=
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
if
s
.
isBetterGeminiAccount
(
acc
,
selected
)
{
selected
=
acc
}
}
}
}
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available Gemini accounts supporting model: %s"
,
requestedModel
)
return
selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func
(
s
*
GeminiMessagesCompatService
)
isBetterGeminiAccount
(
candidate
,
current
*
Account
)
bool
{
// 优先级更高(数值更小)
if
candidate
.
Priority
<
current
.
Priority
{
return
true
}
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
if
candidate
.
Priority
>
current
.
Priority
{
return
false
}
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
// 同优先级,比较最后使用时间
switch
{
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
!=
nil
:
// candidate 从未使用,优先
return
true
case
candidate
.
LastUsedAt
!=
nil
&&
current
.
LastUsedAt
==
nil
:
// current 从未使用,保持
return
false
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
==
nil
:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return
candidate
.
Type
==
AccountTypeOAuth
&&
current
.
Type
!=
AccountTypeOAuth
default
:
// 都使用过,选择最久未使用的
return
candidate
.
LastUsedAt
.
Before
(
*
current
.
LastUsedAt
)
}
return
selected
,
nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
f6ed3d14
...
...
@@ -17,6 +17,8 @@ import (
type
mockAccountRepoForGemini
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listByGroupFunc
func
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
listByPlatformFunc
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
}
func
(
m
*
mockAccountRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
...
...
@@ -104,6 +106,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
if
m
.
listByPlatformFunc
!=
nil
{
return
m
.
listByPlatformFunc
(
ctx
,
platforms
)
}
var
result
[]
Account
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
...
...
@@ -117,6 +122,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
if
m
.
listByGroupFunc
!=
nil
{
return
m
.
listByGroupFunc
(
ctx
,
groupID
,
platforms
)
}
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
m
*
mockAccountRepoForGemini
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
...
...
@@ -212,6 +220,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type
mockGatewayCacheForGemini
struct
{
sessionBindings
map
[
string
]
int64
deletedSessions
map
[
string
]
int
}
func
(
m
*
mockGatewayCacheForGemini
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
...
...
@@ -233,6 +242,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return
nil
}
func
(
m
*
mockGatewayCacheForGemini
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
if
m
.
sessionBindings
==
nil
{
return
nil
}
if
m
.
deletedSessions
==
nil
{
m
.
deletedSessions
=
make
(
map
[
string
]
int
)
}
m
.
deletedSessions
[
sessionHash
]
++
delete
(
m
.
sessionBindings
,
sessionHash
)
return
nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
@@ -523,6 +544,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话未命中,应按优先级选择"
)
})
t
.
Run
(
"粘性会话不可调度-清理并回退选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusDisabled
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
"gemini:session-123"
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"gemini:session-123"
])
})
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
9
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
PlatformAntigravity
)
repo
:=
&
mockAccountRepoForGemini
{
listByGroupFunc
:
func
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
nil
,
nil
},
listByPlatformFunc
:
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
nil
},
accountsByID
:
map
[
int64
]
*
Account
{
1
:
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-1.0-pro"
:
"gemini-1.0-pro"
}},
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"supporting model"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-999"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-999"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
excluded
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
listByPlatformFunc
:
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
nil
,
errors
.
New
(
"query failed"
)
},
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"query accounts failed"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
oldTime
:=
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
newTime
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
newTime
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
oldTime
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
...
...
backend/internal/service/openai_gateway_service.go
View file @
f6ed3d14
...
...
@@ -180,81 +180,164 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func
(
s
*
OpenAIGatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 1. Check sticky session
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// Refresh sticky session TTL
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
openaiStickySessionTTL
)
cacheKey
:=
"openai:"
+
sessionHash
// 1. 尝试粘性会话命中
// Try sticky session hit
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
cacheKey
,
requestedModel
,
excludedIDs
);
account
!=
nil
{
return
account
,
nil
}
// 2. 获取可调度的 OpenAI 账号
// Get schedulable OpenAI accounts
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected
:=
s
.
selectBestAccount
(
accounts
,
requestedModel
,
excludedIDs
)
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available OpenAI accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available OpenAI accounts"
)
}
// 4. 设置粘性会话绑定
// Set sticky session binding
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
openaiStickySessionTTL
)
}
// 2. Get schedulable OpenAI accounts
accounts
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
)
return
selected
,
nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func
(
s
*
OpenAIGatewayService
)
tryStickySessionHit
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
,
cacheKey
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
if
sessionHash
==
""
{
return
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
if
err
!=
nil
||
accountID
<=
0
{
return
nil
}
if
_
,
excluded
:=
excludedIDs
[
accountID
];
excluded
{
return
nil
}
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
return
nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if
shouldClearStickySession
(
account
)
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
return
nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if
!
account
.
IsSchedulable
()
||
!
account
.
IsOpenAI
()
{
return
nil
}
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
openaiStickySessionTTL
)
return
account
}
// 3. Select by priority + LRU
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func
(
s
*
OpenAIGatewayService
)
selectBestAccount
(
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
// 跳过被排除的账号
// Skip excluded accounts
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
if
!
acc
.
IsSchedulable
()
||
!
acc
.
IsOpenAI
()
{
continue
}
// 检查模型支持
// Check model support
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if
selected
==
nil
{
selected
=
acc
continue
}
// Lower priority value means higher priority
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (both never used)
default
:
// Same priority, select least recently used
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
if
s
.
isBetterAccount
(
acc
,
selected
)
{
selected
=
acc
}
}
}
}
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available OpenAI accounts supporting model: %s"
,
requestedModel
)
return
selected
}
// isBetterAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
//
// isBetterAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
func
(
s
*
OpenAIGatewayService
)
isBetterAccount
(
candidate
,
current
*
Account
)
bool
{
// 优先级更高(数值更小)
// Higher priority (lower value)
if
candidate
.
Priority
<
current
.
Priority
{
return
true
}
return
nil
,
errors
.
New
(
"no available OpenAI accounts"
)
if
candidate
.
Priority
>
current
.
Priority
{
return
false
}
// 4. Set sticky session
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
,
selected
.
ID
,
openaiStickySessionTTL
)
// 同优先级,比较最后使用时间
// Same priority, compare last used time
switch
{
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
!=
nil
:
// candidate 从未使用,优先
return
true
case
candidate
.
LastUsedAt
!=
nil
&&
current
.
LastUsedAt
==
nil
:
// current 从未使用,保持
return
false
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
==
nil
:
// 都未使用,保持
return
false
default
:
// 都使用过,选择最久未使用的
return
candidate
.
LastUsedAt
.
Before
(
*
current
.
LastUsedAt
)
}
return
selected
,
nil
}
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
...
...
@@ -325,7 +408,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
"openai:"
+
sessionHash
)
}
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
...
...
@@ -352,6 +440,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
}
}
// ============ Layer 2: Load-aware selection ============
candidates
:=
make
([]
*
Account
,
0
,
len
(
accounts
))
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
f6ed3d14
This diff is collapsed.
Click to expand it.
backend/internal/service/sticky_session_test.go
0 → 100644
View file @
f6ed3d14
//go:build unit
// Package service 提供 API 网关核心服务。
// 本文件包含 shouldClearStickySession 函数的单元测试,
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
//
// This file contains unit tests for the shouldClearStickySession function,
// verifying correct sticky session clearing behavior under various account states.
package
service
import
(
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
// 验证在以下情况下是否正确判断需要清理粘性会话:
// - nil 账号:不清理(返回 false)
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
func
TestShouldClearStickySession
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
future
:=
now
.
Add
(
1
*
time
.
Hour
)
past
:=
now
.
Add
(
-
1
*
time
.
Hour
)
tests
:=
[]
struct
{
name
string
account
*
Account
want
bool
}{
{
name
:
"nil account"
,
account
:
nil
,
want
:
false
},
{
name
:
"status error"
,
account
:
&
Account
{
Status
:
StatusError
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"status disabled"
,
account
:
&
Account
{
Status
:
StatusDisabled
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"schedulable false"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
false
},
want
:
true
},
{
name
:
"temp unschedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
future
},
want
:
true
},
{
name
:
"temp unschedulable expired"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
past
},
want
:
false
},
{
name
:
"active schedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
},
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
shouldClearStickySession
(
tt
.
account
))
})
}
}
backend/internal/service/usage_cleanup_service_test.go
View file @
f6ed3d14
...
...
@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
deleteCalls
,
3
)
require
.
Equal
(
t
,
2
,
repo
.
deleteCalls
[
0
]
.
limit
)
require
.
True
(
t
,
repo
.
deleteCalls
[
0
]
.
filters
.
StartTime
.
Equal
(
start
))
require
.
True
(
t
,
repo
.
deleteCalls
[
0
]
.
filters
.
EndTime
.
Equal
(
end
))
require
.
Len
(
t
,
repo
.
markSucceeded
,
1
)
require
.
Empty
(
t
,
repo
.
markFailed
)
require
.
Equal
(
t
,
int64
(
5
),
repo
.
markSucceeded
[
0
]
.
taskID
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment