Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c7137dff
Unverified
Commit
c7137dff
authored
Mar 24, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 24, 2026
Browse files
Merge pull request #1218 from LvyuanW/openai-runtime-recheck
fix(openai): prevent rescheduling rate-limited accounts
parents
5a3375ce
fef9259a
Changes
17
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/account_repo.go
View file @
c7137dff
...
@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
...
@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return
nil
return
nil
}
}
func
(
r
*
accountRepository
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
_
,
err
:=
r
.
client
.
Account
.
UpdateOneID
(
id
)
.
SetCredentials
(
normalizeJSONMap
(
credentials
))
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
}
r
.
syncSchedulerAccountSnapshot
(
ctx
,
id
)
return
nil
}
func
(
r
*
accountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
accountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
groupIDs
,
err
:=
r
.
loadAccountGroupIDs
(
ctx
,
id
)
groupIDs
,
err
:=
r
.
loadAccountGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/service/account_credentials_persistence.go
0 → 100644
View file @
c7137dff
package
service
import
"context"
type
accountCredentialsUpdater
interface
{
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
}
func
persistAccountCredentials
(
ctx
context
.
Context
,
repo
AccountRepository
,
account
*
Account
,
credentials
map
[
string
]
any
)
error
{
if
repo
==
nil
||
account
==
nil
{
return
nil
}
account
.
Credentials
=
cloneCredentials
(
credentials
)
if
updater
,
ok
:=
any
(
repo
)
.
(
accountCredentialsUpdater
);
ok
{
return
updater
.
UpdateCredentials
(
ctx
,
account
.
ID
,
account
.
Credentials
)
}
return
repo
.
Update
(
ctx
,
account
)
}
func
cloneCredentials
(
in
map
[
string
]
any
)
map
[
string
]
any
{
if
in
==
nil
{
return
map
[
string
]
any
{}
}
out
:=
make
(
map
[
string
]
any
,
len
(
in
))
for
k
,
v
:=
range
in
{
out
[
k
]
=
v
}
return
out
}
backend/internal/service/antigravity_token_provider.go
View file @
c7137dff
...
@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
...
@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p
.
markBackfillAttempted
(
account
.
ID
)
p
.
markBackfillAttempted
(
account
.
ID
)
if
projectID
,
err
:=
p
.
antigravityOAuthService
.
FillProjectID
(
ctx
,
account
,
accessToken
);
err
==
nil
&&
projectID
!=
""
{
if
projectID
,
err
:=
p
.
antigravityOAuthService
.
FillProjectID
(
ctx
,
account
,
accessToken
);
err
==
nil
&&
projectID
!=
""
{
account
.
Credentials
[
"project_id"
]
=
projectID
account
.
Credentials
[
"project_id"
]
=
projectID
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
if
updateErr
:=
p
ersistAccountCredentials
(
ctx
,
p
.
accountRepo
,
account
,
account
.
Credentials
);
updateErr
!=
nil
{
slog
.
Warn
(
"antigravity_project_id_backfill_persist_failed"
,
slog
.
Warn
(
"antigravity_project_id_backfill_persist_failed"
,
"account_id"
,
account
.
ID
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
,
"error"
,
updateErr
,
...
...
backend/internal/service/crs_sync_service.go
View file @
c7137dff
...
@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation
// 🔄 Refresh OAuth token after creation
if
targetType
==
AccountTypeOAuth
{
if
targetType
==
AccountTypeOAuth
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
}
}
}
}
item
.
Action
=
"created"
item
.
Action
=
"created"
...
@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
// 🔄 Refresh OAuth token after update
if
targetType
==
AccountTypeOAuth
{
if
targetType
==
AccountTypeOAuth
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
}
}
}
}
...
@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
}
// 🔄 Refresh OAuth token after creation
// 🔄 Refresh OAuth token after creation
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
}
}
item
.
Action
=
"created"
item
.
Action
=
"created"
result
.
Created
++
result
.
Created
++
...
@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
// 🔄 Refresh OAuth token after update
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
}
}
item
.
Action
=
"updated"
item
.
Action
=
"updated"
...
@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
continue
}
}
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
account
);
refreshedCreds
!=
nil
{
account
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
account
)
}
}
item
.
Action
=
"created"
item
.
Action
=
"created"
result
.
Created
++
result
.
Created
++
...
@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
}
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
if
refreshedCreds
:=
s
.
refreshOAuthToken
(
ctx
,
existing
);
refreshedCreds
!=
nil
{
existing
.
Credentials
=
refreshedCreds
_
=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
existing
,
refreshedCreds
)
_
=
s
.
accountRepo
.
Update
(
ctx
,
existing
)
}
}
item
.
Action
=
"updated"
item
.
Action
=
"updated"
...
...
backend/internal/service/gemini_token_provider.go
View file @
c7137dff
...
@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
tierID
!=
""
{
if
tierID
!=
""
{
account
.
Credentials
[
"tier_id"
]
=
tierID
account
.
Credentials
[
"tier_id"
]
=
tierID
}
}
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
p
ersistAccountCredentials
(
ctx
,
p
.
accountRepo
,
account
,
account
.
Credentials
)
}
}
}
}
...
...
backend/internal/service/oauth_refresh_api.go
View file @
c7137dff
...
@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
...
@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB
// 5. 设置版本号 + 更新 DB
if
newCredentials
!=
nil
{
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
freshAccount
.
Credentials
=
newCredentials
if
updateErr
:=
persistAccountCredentials
(
ctx
,
api
.
accountRepo
,
freshAccount
,
newCredentials
);
updateErr
!=
nil
{
if
updateErr
:=
api
.
accountRepo
.
Update
(
ctx
,
freshAccount
);
updateErr
!=
nil
{
slog
.
Error
(
"oauth_refresh_update_failed"
,
slog
.
Error
(
"oauth_refresh_update_failed"
,
"account_id"
,
freshAccount
.
ID
,
"account_id"
,
freshAccount
.
ID
,
"error"
,
updateErr
,
"error"
,
updateErr
,
...
...
backend/internal/service/oauth_refresh_api_test.go
View file @
c7137dff
...
@@ -16,10 +16,11 @@ import (
...
@@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type
refreshAPIAccountRepo
struct
{
type
refreshAPIAccountRepo
struct
{
mockAccountRepoForGemini
mockAccountRepoForGemini
account
*
Account
// returned by GetByID
account
*
Account
// returned by GetByID
getByIDErr
error
getByIDErr
error
updateErr
error
updateErr
error
updateCalls
int
updateCalls
int
updateCredentialsCalls
int
}
}
func
(
r
*
refreshAPIAccountRepo
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
Account
,
error
)
{
func
(
r
*
refreshAPIAccountRepo
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
Account
,
error
)
{
...
@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
...
@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return
r
.
updateErr
return
r
.
updateErr
}
}
func
(
r
*
refreshAPIAccountRepo
)
UpdateCredentials
(
_
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCalls
++
r
.
updateCredentialsCalls
++
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
if
r
.
account
==
nil
||
r
.
account
.
ID
!=
id
{
r
.
account
=
&
Account
{
ID
:
id
}
}
r
.
account
.
Credentials
=
cloneCredentials
(
credentials
)
return
nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type
refreshAPIExecutorStub
struct
{
type
refreshAPIExecutorStub
struct
{
needsRefresh
bool
needsRefresh
bool
...
@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
...
@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require
.
Equal
(
t
,
"new-token"
,
result
.
NewCredentials
[
"access_token"
])
require
.
Equal
(
t
,
"new-token"
,
result
.
NewCredentials
[
"access_token"
])
require
.
NotNil
(
t
,
result
.
NewCredentials
[
"_token_version"
])
// version stamp set
require
.
NotNil
(
t
,
result
.
NewCredentials
[
"_token_version"
])
// version stamp set
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
}
}
func
TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState
(
t
*
testing
.
T
)
{
resetAt
:=
time
.
Now
()
.
Add
(
45
*
time
.
Minute
)
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
RateLimitResetAt
:
&
resetAt
,
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"safe-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
NotNil
(
t
,
repo
.
account
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
,
*
repo
.
account
.
RateLimitResetAt
,
time
.
Second
)
}
func
TestRefreshIfNeeded_LockHeld
(
t
*
testing
.
T
)
{
func
TestRefreshIfNeeded_LockHeld
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformAnthropic
}
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
...
@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
...
@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid_grant"
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid_grant"
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock still released via defer
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock still released via defer
}
}
...
@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
...
@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result
:=
MergeCredentials
(
old
,
new
)
result
:=
MergeCredentials
(
old
,
new
)
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
}
}
// ========== BuildClaudeAccountCredentials tests ==========
// ========== BuildClaudeAccountCredentials tests ==========
...
...
backend/internal/service/openai_account_scheduler.go
View file @
c7137dff
...
@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
...
@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
return
nil
,
nil
}
}
account
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
req
.
RequestedModel
)
if
account
==
nil
{
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
...
@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
continue
}
}
fresh
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
req
.
RequestedModel
)
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
if
acquireErr
!=
nil
{
if
acquireErr
!=
nil
{
return
nil
,
len
(
candidates
),
topK
,
loadSkew
,
acquireErr
return
nil
,
len
(
candidates
),
topK
,
loadSkew
,
acquireErr
...
...
backend/internal/service/openai_account_scheduler_test.go
View file @
c7137dff
...
@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
...
@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require
.
Equal
(
t
,
int64
(
32002
),
account
.
ID
)
require
.
Equal
(
t
,
int64
(
32002
),
account
.
ID
)
}
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10103
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleSticky
:=
&
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleBackup
:=
&
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbSticky
:=
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbBackup
:=
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_db_runtime_recheck"
:
33001
}}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
staleSticky
,
staleBackup
},
accountsByID
:
map
[
int64
]
*
Account
{
33001
:
staleSticky
,
33002
:
staleBackup
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbSticky
,
dbBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_db_runtime_recheck"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
selection
)
require
.
NotNil
(
t
,
selection
.
Account
)
require
.
Equal
(
t
,
int64
(
33002
),
selection
.
Account
.
ID
)
require
.
Equal
(
t
,
openAIAccountScheduleLayerLoadBalance
,
decision
.
Layer
)
}
func
TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10104
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
stalePrimary
:=
&
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleSecondary
:=
&
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbPrimary
:=
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbSecondary
:=
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
stalePrimary
,
staleSecondary
},
accountsByID
:
map
[
int64
]
*
Account
{
34001
:
stalePrimary
,
34002
:
staleSecondary
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbPrimary
,
dbSecondary
}},
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
int64
(
34002
),
account
.
ID
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
groupID
:=
int64
(
9
)
groupID
:=
int64
(
9
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
c7137dff
...
@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
...
@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
return
nil
}
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
nil
}
// 刷新会话 TTL 并返回账号
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
// Refresh session TTL and return account
...
@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
...
@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if
fresh
==
nil
{
if
fresh
==
nil
{
continue
continue
}
}
fresh
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
requestedModel
)
if
fresh
==
nil
{
continue
}
// 选择优先级最高且最久未使用的账号
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
// Select highest priority and least recently used
...
@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
err
==
nil
&&
result
.
Acquired
{
if
account
==
nil
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
&
AccountSelectionResult
{
}
else
{
Account
:
account
,
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
Acquired
:
true
,
if
err
==
nil
&&
result
.
Acquired
{
ReleaseFunc
:
result
.
ReleaseFunc
,
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
},
nil
return
&
AccountSelectionResult
{
}
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
return
&
AccountSelectionResult
{
Account
:
account
,
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
},
nil
},
nil
}
}
}
}
}
}
}
...
@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
...
@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return
fresh
return
fresh
}
}
func
(
s
*
OpenAIGatewayService
)
recheckSelectedOpenAIAccountFromDB
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
*
Account
{
if
account
==
nil
{
return
nil
}
if
s
.
schedulerSnapshot
==
nil
||
s
.
accountRepo
==
nil
{
return
account
}
latest
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latest
==
nil
{
return
nil
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
latest
,
time
.
Now
())
if
!
latest
.
IsSchedulable
()
||
!
latest
.
IsOpenAI
()
{
return
nil
}
if
requestedModel
!=
""
&&
!
latest
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
return
latest
}
func
(
s
*
OpenAIGatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
var
(
var
(
account
*
Account
account
*
Account
...
...
backend/internal/service/openai_ws_account_sticky_test.go
View file @
c7137dff
...
@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
...
@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require
.
Zero
(
t
,
boundAccountID
)
require
.
Zero
(
t
,
boundAccountID
)
}
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
24
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleAccount
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
dbAccount
:=
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
RateLimitResetAt
:
&
rateLimitedUntil
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
cache
:=
&
stubGatewayCache
{}
store
:=
NewOpenAIWSStateStore
(
cache
)
cfg
:=
newOpenAIWSV2TestConfig
()
snapshotCache
:=
&
openAISnapshotCacheStub
{
accountsByID
:
map
[
int64
]
*
Account
{
dbAccount
.
ID
:
staleAccount
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbAccount
}},
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
openaiWSStateStore
:
store
,
schedulerSnapshot
:
&
SchedulerSnapshotService
{
cache
:
snapshotCache
},
}
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
,
dbAccount
.
ID
,
time
.
Hour
))
selection
,
err
:=
svc
.
SelectAccountByPreviousResponseID
(
ctx
,
&
groupID
,
"resp_prev_db_rl"
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
selection
,
"DB 中已限流的账号不应继续命中 previous_response_id 粘连"
)
boundAccountID
,
getErr
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
)
require
.
NoError
(
t
,
getErr
)
require
.
Zero
(
t
,
boundAccountID
)
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
groupID
:=
int64
(
23
)
groupID
:=
int64
(
23
)
...
...
backend/internal/service/openai_ws_forwarder.go
View file @
c7137dff
...
@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
...
@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
,
nil
return
nil
,
nil
}
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
store
.
DeleteResponseAccount
(
ctx
,
derefGroupID
(
groupID
),
responseID
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
acquireErr
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
...
backend/internal/service/ratelimit_service.go
View file @
c7137dff
...
@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account
.
Credentials
=
make
(
map
[
string
]
any
)
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
}
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
account
.
Credentials
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
else
{
}
else
{
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
...
...
backend/internal/service/ratelimit_service_401_test.go
View file @
c7137dff
...
@@ -15,9 +15,11 @@ import (
...
@@ -15,9 +15,11 @@ import (
type
rateLimitAccountRepoStub
struct
{
type
rateLimitAccountRepoStub
struct
{
mockAccountRepoForGemini
mockAccountRepoForGemini
setErrorCalls
int
setErrorCalls
int
tempCalls
int
tempCalls
int
lastErrorMsg
string
updateCredentialsCalls
int
lastCredentials
map
[
string
]
any
lastErrorMsg
string
}
}
func
(
r
*
rateLimitAccountRepoStub
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
func
(
r
*
rateLimitAccountRepoStub
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
...
@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
...
@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return
nil
return
nil
}
}
func
(
r
*
rateLimitAccountRepoStub
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCredentialsCalls
++
r
.
lastCredentials
=
cloneCredentials
(
credentials
)
return
nil
}
type
tokenCacheInvalidatorRecorder
struct
{
type
tokenCacheInvalidatorRecorder
struct
{
accounts
[]
*
Account
accounts
[]
*
Account
err
error
err
error
...
@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
...
@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require
.
True
(
t
,
shouldDisable
)
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
}
}
...
@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
...
@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
require
.
Empty
(
t
,
invalidator
.
accounts
)
require
.
Empty
(
t
,
invalidator
.
accounts
)
}
}
func
TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
NotEmpty
(
t
,
repo
.
lastCredentials
[
"expires_at"
])
}
backend/internal/service/sora_sdk_client.go
View file @
c7137dff
...
@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
...
@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
}
}
if
c
.
accountRepo
!=
nil
{
if
c
.
accountRepo
!=
nil
{
if
err
:=
c
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
&&
c
.
debugEnabled
()
{
if
err
:=
persistAccountCredentials
(
ctx
,
c
.
accountRepo
,
account
,
account
.
Credentials
);
err
!=
nil
&&
c
.
debugEnabled
()
{
c
.
debugLogf
(
"persist_recovered_token_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
c
.
debugLogf
(
"persist_recovered_token_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
}
}
}
}
...
...
backend/internal/service/token_refresh_service.go
View file @
c7137dff
...
@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials
,
err
=
refresher
.
Refresh
(
ctx
,
account
)
newCredentials
,
err
=
refresher
.
Refresh
(
ctx
,
account
)
if
newCredentials
!=
nil
{
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
if
saveErr
:=
persistAccountCredentials
(
ctx
,
s
.
accountRepo
,
account
,
newCredentials
);
saveErr
!=
nil
{
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
}
}
}
}
...
...
backend/internal/service/token_refresh_service_test.go
View file @
c7137dff
...
@@ -14,19 +14,40 @@ import (
...
@@ -14,19 +14,40 @@ import (
type
tokenRefreshAccountRepo
struct
{
type
tokenRefreshAccountRepo
struct
{
mockAccountRepoForGemini
mockAccountRepoForGemini
updateCalls
int
updateCalls
int
setErrorCalls
int
fullUpdateCalls
int
clearTempCalls
int
updateCredentialsCalls
int
lastAccount
*
Account
setErrorCalls
int
updateErr
error
clearTempCalls
int
lastAccount
*
Account
updateErr
error
}
}
func
(
r
*
tokenRefreshAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
func
(
r
*
tokenRefreshAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
r
.
updateCalls
++
r
.
updateCalls
++
r
.
fullUpdateCalls
++
r
.
lastAccount
=
account
r
.
lastAccount
=
account
return
r
.
updateErr
return
r
.
updateErr
}
}
func
(
r
*
tokenRefreshAccountRepo
)
UpdateCredentials
(
ctx
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCalls
++
r
.
updateCredentialsCalls
++
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
cloned
:=
cloneCredentials
(
credentials
)
if
r
.
accountsByID
!=
nil
{
if
acc
,
ok
:=
r
.
accountsByID
[
id
];
ok
&&
acc
!=
nil
{
acc
.
Credentials
=
cloned
r
.
lastAccount
=
acc
return
nil
}
}
r
.
lastAccount
=
&
Account
{
ID
:
id
,
Credentials
:
cloned
}
return
nil
}
func
(
r
*
tokenRefreshAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
func
(
r
*
tokenRefreshAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
++
r
.
setErrorCalls
++
return
nil
return
nil
...
@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
...
@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
0
,
repo
.
fullUpdateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
require
.
Equal
(
t
,
"new-token"
,
account
.
GetCredential
(
"access_token"
))
require
.
Equal
(
t
,
"new-token"
,
account
.
GetCredential
(
"access_token"
))
}
}
...
@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
...
@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// 所有 OAuth 账户刷新后触发缓存失效
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// 所有 OAuth 账户刷新后触发缓存失效
}
}
func
TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
)
resetAt
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
account
:=
&
Account
{
ID
:
17
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
RateLimitResetAt
:
&
resetAt
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
},
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"new-token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
0
,
repo
.
fullUpdateCalls
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
,
*
account
.
RateLimitResetAt
,
time
.
Second
)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func
TestTokenRefreshService_RefreshWithRetry_UpdateFailed
(
t
*
testing
.
T
)
{
func
TestTokenRefreshService_RefreshWithRetry_UpdateFailed
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{
updateErr
:
errors
.
New
(
"update failed"
)}
repo
:=
&
tokenRefreshAccountRepo
{
updateErr
:
errors
.
New
(
"update failed"
)}
...
@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
...
@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
clearTempCalls
)
// DB 清除
require
.
Equal
(
t
,
1
,
repo
.
clearTempCalls
)
// DB 清除
require
.
Equal
(
t
,
1
,
tempCache
.
deleteCalls
)
// Redis 缓存也应清除
require
.
Equal
(
t
,
1
,
tempCache
.
deleteCalls
)
// Redis 缓存也应清除
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment