Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c7137dff
Unverified
Commit
c7137dff
authored
Mar 24, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 24, 2026
Browse files
Merge pull request #1218 from LvyuanW/openai-runtime-recheck
fix(openai): prevent rescheduling rate-limited accounts
parents
5a3375ce
fef9259a
Changes
17
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/account_repo.go
View file @
c7137dff
...
...
@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return
nil
}
func
(
r
*
accountRepository
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
_
,
err
:=
r
.
client
.
Account
.
UpdateOneID
(
id
)
.
SetCredentials
(
normalizeJSONMap
(
credentials
))
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
}
func
(
r
*
accountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
groupIDs
,
err
:=
r
.
loadAccountGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
...
...
backend/internal/service/account_credentials_persistence.go
0 → 100644
View file @
c7137dff
package
service
import
"context"
type
accountCredentialsUpdater
interface
{
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
}
func
persistAccountCredentials
(
ctx
context
.
Context
,
repo
AccountRepository
,
account
*
Account
,
credentials
map
[
string
]
any
)
error
{
if
repo
==
nil
||
account
==
nil
{
return
nil
}
account
.
Credentials
=
cloneCredentials
(
credentials
)
if
updater
,
ok
:=
any
(
repo
)
.
(
accountCredentialsUpdater
);
ok
{
return
updater
.
UpdateCredentials
(
ctx
,
account
.
ID
,
account
.
Credentials
)
}
return
repo
.
Update
(
ctx
,
account
)
}
func
cloneCredentials
(
in
map
[
string
]
any
)
map
[
string
]
any
{
if
in
==
nil
{
return
map
[
string
]
any
{}
}
out
:=
make
(
map
[
string
]
any
,
len
(
in
))
for
k
,
v
:=
range
in
{
out
[
k
]
=
v
}
return
out
}
backend/internal/service/antigravity_token_provider.go
View file @
c7137dff
...
...
@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p
.
markBackfillAttempted
(
account
.
ID
)
if
projectID
,
err
:=
p
.
antigravityOAuthService
.
FillProjectID
(
ctx
,
account
,
accessToken
);
err
==
nil
&&
projectID
!=
""
{
account
.
Credentials
[
"project_id"
]
=
projectID
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
if
updateErr
:=
p
ersistAccountCredentials
(
ctx
,
p
.
accountRepo
,
account
,
account
.
Credentials
);
updateErr
!=
nil
{
slog
.
Warn
(
"antigravity_project_id_backfill_persist_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
,
...
...
backend/internal/service/crs_sync_service.go
View file @
c7137dff
...
...
@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation
if
targetType
==
AccountTypeOAuth
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
}
}
item
.
Action
=
"created"
...
...
@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
if
targetType
==
AccountTypeOAuth
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
}
}
...
...
@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// 🔄 Refresh OAuth token after creation
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
}
item
.
Action
=
"created"
result
.
Created
++
...
...
@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
}
item
.
Action
=
"updated"
...
...
@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
}
item
.
Action
=
"created"
result
.
Created
++
...
...
@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
}
item
.
Action
=
"updated"
...
...
backend/internal/service/gemini_token_provider.go
View file @
c7137dff
...
...
@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
tierID
!=
""
{
account
.
Credentials
[
"tier_id"
]
=
tierID
}
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
p
ersistAccountCredentials
(
ctx
,
p
.
accountRepo
,
account
,
account
.
Credentials
)
}
}
...
...
backend/internal/service/oauth_refresh_api.go
View file @
c7137dff
...
...
@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
freshAccount
.
Credentials
=
newCredentials
if
updateErr
:=
api
.
accountRepo
.
Update
(
ctx
,
freshAccount
);
updateErr
!=
nil
{
if
updateErr
:=
persistAccountCredentials
(
ctx
,
api
.
accountRepo
,
freshAccount
,
newCredentials
);
updateErr
!=
nil
{
slog
.
Error
(
"oauth_refresh_update_failed"
,
"account_id"
,
freshAccount
.
ID
,
"error"
,
updateErr
,
...
...
backend/internal/service/oauth_refresh_api_test.go
View file @
c7137dff
...
...
@@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type
refreshAPIAccountRepo
struct
{
mockAccountRepoForGemini
account
*
Account
// returned by GetByID
getByIDErr
error
updateErr
error
updateCalls
int
account
*
Account
// returned by GetByID
getByIDErr
error
updateErr
error
updateCalls
int
updateCredentialsCalls
int
}
func
(
r
*
refreshAPIAccountRepo
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
Account
,
error
)
{
...
...
@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return
r
.
updateErr
}
func
(
r
*
refreshAPIAccountRepo
)
UpdateCredentials
(
_
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCalls
++
r
.
updateCredentialsCalls
++
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
if
r
.
account
==
nil
||
r
.
account
.
ID
!=
id
{
r
.
account
=
&
Account
{
ID
:
id
}
}
r
.
account
.
Credentials
=
cloneCredentials
(
credentials
)
return
nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type
refreshAPIExecutorStub
struct
{
needsRefresh
bool
...
...
@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require
.
Equal
(
t
,
"new-token"
,
result
.
NewCredentials
[
"access_token"
])
require
.
NotNil
(
t
,
result
.
NewCredentials
[
"_token_version"
])
// version stamp set
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState
(
t
*
testing
.
T
)
{
resetAt
:=
time
.
Now
()
.
Add
(
45
*
time
.
Minute
)
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
RateLimitResetAt
:
&
resetAt
,
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"safe-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
NotNil
(
t
,
repo
.
account
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
,
*
repo
.
account
.
RateLimitResetAt
,
time
.
Second
)
}
func
TestRefreshIfNeeded_LockHeld
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
...
...
@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid_grant"
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock still released via defer
}
...
...
@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result
:=
MergeCredentials
(
old
,
new
)
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
}
// ========== BuildClaudeAccountCredentials tests ==========
...
...
backend/internal/service/openai_account_scheduler.go
View file @
c7137dff
...
...
@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
}
account
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
req
.
RequestedModel
)
if
account
==
nil
{
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
...
@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
}
fresh
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
req
.
RequestedModel
)
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
if
acquireErr
!=
nil
{
return
nil
,
len
(
candidates
),
topK
,
loadSkew
,
acquireErr
...
...
backend/internal/service/openai_account_scheduler_test.go
View file @
c7137dff
...
...
@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require
.
Equal
(
t
,
int64
(
32002
),
account
.
ID
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10103
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleSticky
:=
&
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleBackup
:=
&
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbSticky
:=
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbBackup
:=
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_db_runtime_recheck"
:
33001
}}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
staleSticky
,
staleBackup
},
accountsByID
:
map
[
int64
]
*
Account
{
33001
:
staleSticky
,
33002
:
staleBackup
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbSticky
,
dbBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_db_runtime_recheck"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
selection
)
require
.
NotNil
(
t
,
selection
.
Account
)
require
.
Equal
(
t
,
int64
(
33002
),
selection
.
Account
.
ID
)
require
.
Equal
(
t
,
openAIAccountScheduleLayerLoadBalance
,
decision
.
Layer
)
}
func
TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10104
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
stalePrimary
:=
&
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleSecondary
:=
&
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbPrimary
:=
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbSecondary
:=
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
stalePrimary
,
staleSecondary
},
accountsByID
:
map
[
int64
]
*
Account
{
34001
:
stalePrimary
,
34002
:
staleSecondary
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbPrimary
,
dbSecondary
}},
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
int64
(
34002
),
account
.
ID
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
9
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
c7137dff
...
...
@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
...
...
@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if
fresh
==
nil
{
continue
}
fresh
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
requestedModel
)
if
fresh
==
nil
{
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
...
...
@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
}
else
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
...
...
@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return
fresh
}
func
(
s
*
OpenAIGatewayService
)
recheckSelectedOpenAIAccountFromDB
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
*
Account
{
if
account
==
nil
{
return
nil
}
if
s
.
schedulerSnapshot
==
nil
||
s
.
accountRepo
==
nil
{
return
account
}
latest
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latest
==
nil
{
return
nil
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
latest
,
time
.
Now
())
if
!
latest
.
IsSchedulable
()
||
!
latest
.
IsOpenAI
()
{
return
nil
}
if
requestedModel
!=
""
&&
!
latest
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
return
latest
}
func
(
s
*
OpenAIGatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
var
(
account
*
Account
...
...
backend/internal/service/openai_ws_account_sticky_test.go
View file @
c7137dff
...
...
@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require
.
Zero
(
t
,
boundAccountID
)
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
24
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleAccount
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
dbAccount
:=
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
RateLimitResetAt
:
&
rateLimitedUntil
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
cache
:=
&
stubGatewayCache
{}
store
:=
NewOpenAIWSStateStore
(
cache
)
cfg
:=
newOpenAIWSV2TestConfig
()
snapshotCache
:=
&
openAISnapshotCacheStub
{
accountsByID
:
map
[
int64
]
*
Account
{
dbAccount
.
ID
:
staleAccount
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbAccount
}},
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
openaiWSStateStore
:
store
,
schedulerSnapshot
:
&
SchedulerSnapshotService
{
cache
:
snapshotCache
},
}
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
,
dbAccount
.
ID
,
time
.
Hour
))
selection
,
err
:=
svc
.
SelectAccountByPreviousResponseID
(
ctx
,
&
groupID
,
"resp_prev_db_rl"
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
selection
,
"DB 中已限流的账号不应继续命中 previous_response_id 粘连"
)
boundAccountID
,
getErr
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
)
require
.
NoError
(
t
,
getErr
)
require
.
Zero
(
t
,
boundAccountID
)
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
23
)
...
...
backend/internal/service/openai_ws_forwarder.go
View file @
c7137dff
...
...
@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
,
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
store
.
DeleteResponseAccount
(
ctx
,
derefGroupID
(
groupID
),
responseID
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
...
backend/internal/service/ratelimit_service.go
View file @
c7137dff
...
...
@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
account
.
Credentials
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
...
...
backend/internal/service/ratelimit_service_401_test.go
View file @
c7137dff
...
...
@@ -15,9 +15,11 @@ import (
type
rateLimitAccountRepoStub
struct
{
mockAccountRepoForGemini
setErrorCalls
int
tempCalls
int
lastErrorMsg
string
setErrorCalls
int
tempCalls
int
updateCredentialsCalls
int
lastCredentials
map
[
string
]
any
lastErrorMsg
string
}
func
(
r
*
rateLimitAccountRepoStub
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
...
...
@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return
nil
}
func
(
r
*
rateLimitAccountRepoStub
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCredentialsCalls
++
r
.
lastCredentials
=
cloneCredentials
(
credentials
)
return
nil
}
type
tokenCacheInvalidatorRecorder
struct
{
accounts
[]
*
Account
err
error
...
...
@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
}
...
...
@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
require
.
Empty
(
t
,
invalidator
.
accounts
)
}
func
TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
NotEmpty
(
t
,
repo
.
lastCredentials
[
"expires_at"
])
}
backend/internal/service/sora_sdk_client.go
View file @
c7137dff
...
...
@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
}
if
c
.
accountRepo
!=
nil
{
if
err
:=
c
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
&&
c
.
debugEnabled
()
{
if
err
:=
persistAccountCredentials
(
ctx
,
c
.
accountRepo
,
account
,
account
.
Credentials
);
err
!=
nil
&&
c
.
debugEnabled
()
{
c
.
debugLogf
(
"persist_recovered_token_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
}
}
...
...
backend/internal/service/token_refresh_service.go
View file @
c7137dff
...
...
@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials
,
err
=
refresher
.
Refresh
(
ctx
,
account
)
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
if
saveErr
:=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
newCredentials
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
}
}
...
...
backend/internal/service/token_refresh_service_test.go
View file @
c7137dff
...
...
@@ -14,19 +14,40 @@ import (
type
tokenRefreshAccountRepo
struct
{
mockAccountRepoForGemini
updateCalls
int
setErrorCalls
int
clearTempCalls
int
lastAccount
*
Account
updateErr
error
updateCalls
int
fullUpdateCalls
int
updateCredentialsCalls
int
setErrorCalls
int
clearTempCalls
int
lastAccount
*
Account
updateErr
error
}
func
(
r
*
tokenRefreshAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
r
.
updateCalls
++
r
.
fullUpdateCalls
++
r
.
lastAccount
=
account
return
r
.
updateErr
}
func
(
r
*
tokenRefreshAccountRepo
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCalls
++
r
.
updateCredentialsCalls
++
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
cloned
:=
cloneCredentials
(
credentials
)
if
r
.
accountsByID
!=
nil
{
if
acc
,
ok
:=
r
.
accountsByID
[
id
];
ok
&&
acc
!=
nil
{
acc
.
Credentials
=
cloned
r
.
lastAccount
=
acc
return
nil
}
}
r
.
lastAccount
=
&
Account
{
ID
:
id
,
Credentials
:
cloned
}
return
nil
}
func
(
r
*
tokenRefreshAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
++
return
nil
...
...
@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
0
,
repo
.
fullUpdateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
require
.
Equal
(
t
,
"new-token"
,
account
.
GetCredential
(
"access_token"
))
}
...
...
@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// 所有 OAuth 账户刷新后触发缓存失效
}
func
TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
)
resetAt
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
account
:=
&
Account
{
ID
:
17
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
RateLimitResetAt
:
&
resetAt
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
},
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"new-token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
0
,
repo
.
fullUpdateCalls
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
,
*
account
.
RateLimitResetAt
,
time
.
Second
)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func
TestTokenRefreshService_RefreshWithRetry_UpdateFailed
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{
updateErr
:
errors
.
New
(
"update failed"
)}
...
...
@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
clearTempCalls
)
// DB 清除
require
.
Equal
(
t
,
1
,
repo
.
clearTempCalls
)
// DB 清除
require
.
Equal
(
t
,
1
,
tempCache
.
deleteCalls
)
// Redis 缓存也应清除
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment