Unverified commit c7137dff, authored by Wesley Liddick and committed by GitHub
Browse files

Merge pull request #1218 from LvyuanW/openai-runtime-recheck

fix(openai): prevent rescheduling rate-limited accounts
parents 5a3375ce fef9259a
...@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return nil return nil
} }
// UpdateCredentials stores a new credentials map for the account with the
// given id. SetCredentials is the only setter applied to the update builder,
// and the map is passed through normalizeJSONMap first.
//
// A failed save is returned via translatePersistenceError with
// service.ErrAccountNotFound supplied as the not-found error. On success the
// scheduler snapshot for the account is re-synced before returning nil.
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
	update := r.client.Account.UpdateOneID(id).SetCredentials(normalizeJSONMap(credentials))
	if _, err := update.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}

	// Only reached after a successful save.
	r.syncSchedulerAccountSnapshot(ctx, id)
	return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id) groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil { if err != nil {
......
package service
import "context"
// accountCredentialsUpdater is an optional capability of an account
// repository: persisting a credentials map for an account addressed by id.
// persistAccountCredentials probes for it with a type assertion and falls
// back to a full Update when the repository does not provide it.
type accountCredentialsUpdater interface {
// UpdateCredentials persists credentials for the account identified by id.
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
}
// persistAccountCredentials assigns a shallow copy of credentials to
// account.Credentials and then persists it through repo.
//
// When repo also implements accountCredentialsUpdater, UpdateCredentials is
// called with the account ID and the copied map; otherwise the whole account
// is written with repo.Update. A nil repo or nil account is a no-op that
// returns nil. The in-memory assignment happens before the write, so
// account.Credentials holds the new map even if persistence fails.
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
	if account == nil || repo == nil {
		return nil
	}

	account.Credentials = cloneCredentials(credentials)

	updater, supported := any(repo).(accountCredentialsUpdater)
	if !supported {
		// No credentials-only path available: fall back to a full update.
		return repo.Update(ctx, account)
	}
	return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
}
// cloneCredentials returns a shallow copy of in. The result is never nil: a
// nil input yields an empty map. Values are copied as-is, so nested maps or
// slices are shared with the input.
func cloneCredentials(in map[string]any) map[string]any {
	// len and range both accept a nil map, so nil needs no special case.
	dup := make(map[string]any, len(in))
	for key, value := range in {
		dup[key] = value
	}
	return dup
}
...@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p.markBackfillAttempted(account.ID) p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
slog.Warn("antigravity_project_id_backfill_persist_failed", slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID, "account_id", account.ID,
"error", updateErr, "error", updateErr,
......
...@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
} }
item.Action = "created" item.Action = "created"
...@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
} }
...@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
...@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"
...@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
...@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"
......
...@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if tierID != "" { if tierID != "" {
account.Credentials["tier_id"] = tierID account.Credentials["tier_id"] = tierID
} }
_ = p.accountRepo.Update(ctx, account) _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
} }
} }
......
...@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ...@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB // 5. 设置版本号 + 更新 DB
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
freshAccount.Credentials = newCredentials if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
slog.Error("oauth_refresh_update_failed", slog.Error("oauth_refresh_update_failed",
"account_id", freshAccount.ID, "account_id", freshAccount.ID,
"error", updateErr, "error", updateErr,
......
...@@ -16,10 +16,11 @@ import ( ...@@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct { type refreshAPIAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
account *Account // returned by GetByID account *Account // returned by GetByID
getByIDErr error getByIDErr error
updateErr error updateErr error
updateCalls int updateCalls int
updateCredentialsCalls int
} }
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
...@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { ...@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return r.updateErr return r.updateErr
} }
// UpdateCredentials is the test double for the credentials-only update. It
// bumps both updateCalls and updateCredentialsCalls, returns updateErr when
// one is configured, and otherwise stores a copy of credentials on the held
// account. If no account is held, or the held account has a different ID, it
// is replaced by a fresh Account carrying only that ID.
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
	r.updateCredentialsCalls++
	r.updateCalls++

	if failure := r.updateErr; failure != nil {
		return failure
	}

	if held := r.account; held == nil || held.ID != id {
		r.account = &Account{ID: id}
	}
	r.account.Credentials = cloneCredentials(credentials)
	return nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct { type refreshAPIExecutorStub struct {
needsRefresh bool needsRefresh bool
...@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) { ...@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls) require.Equal(t, 1, executor.refreshCalls)
} }
// TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState refreshes an
// account that carries a RateLimitResetAt and checks that the refresh went
// through UpdateCredentials exactly once and that the repository's account
// still has the same reset time afterwards.
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
	limitedUntil := time.Now().Add(45 * time.Minute)
	acct := &Account{
		ID:               11,
		Platform:         PlatformGemini,
		Type:             AccountTypeOAuth,
		RateLimitResetAt: &limitedUntil,
	}

	repo := &refreshAPIAccountRepo{account: acct}
	lockCache := &refreshAPICacheStub{lockResult: true}
	refresher := &refreshAPIExecutorStub{
		needsRefresh: true,
		credentials:  map[string]any{"access_token": "safe-token"},
	}
	refreshAPI := NewOAuthRefreshAPI(repo, lockCache)

	result, err := refreshAPI.RefreshIfNeeded(context.Background(), acct, refresher, 3*time.Minute)

	require.NoError(t, err)
	require.True(t, result.Refreshed)
	require.Equal(t, 1, repo.updateCredentialsCalls)
	// The rate-limit marker must still be present on the repository's account.
	require.NotNil(t, repo.account.RateLimitResetAt)
	require.WithinDuration(t, limitedUntil, *repo.account.RateLimitResetAt, time.Second)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) { func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic} account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account} repo := &refreshAPIAccountRepo{account: account}
...@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) { ...@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant") require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
} }
...@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) { ...@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result := MergeCredentials(old, new) result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
} }
// ========== BuildClaudeAccountCredentials tests ========== // ========== BuildClaudeAccountCredentials tests ==========
......
...@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( ...@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil return nil, nil
} }
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {
...@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ...@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue continue
} }
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil { if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr return nil, len(candidates), topK, loadSkew, acquireErr
......
...@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa ...@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID) require.Equal(t, int64(32002), account.ID)
} }
// TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount
// binds a session to account 33001, which looks schedulable in the snapshot
// cache but carries a future RateLimitResetAt in the repository. The
// scheduler is expected to pick account 33002 through the load-balance layer.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10103)
	limitedUntil := time.Now().Add(30 * time.Minute)

	// Snapshot-cache view: neither account is rate limited.
	cachedSticky := Account{
		ID:          33001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    0,
	}
	cachedBackup := Account{
		ID:          33002,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    5,
	}

	// Repository view: same accounts, but the sticky one is rate limited.
	repoSticky := cachedSticky
	repoSticky.RateLimitResetAt = &limitedUntil
	repoBackup := cachedBackup

	snapshotStub := &openAISnapshotCacheStub{
		snapshotAccounts: []*Account{&cachedSticky, &cachedBackup},
		accountsByID:     map[int64]*Account{33001: &cachedSticky, 33002: &cachedBackup},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{repoSticky, repoBackup}},
		cache:              &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}},
		cfg:                &config.Config{},
		schedulerSnapshot:  &SchedulerSnapshotService{cache: snapshotStub},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)

	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(33002), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
// TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate
// offers two snapshot-cache candidates. Account 34001 has the lower Priority
// value but carries a future RateLimitResetAt in the repository, so the
// selection is expected to return account 34002.
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10104)
	limitedUntil := time.Now().Add(30 * time.Minute)

	// Snapshot-cache view: neither account is rate limited.
	cachedPrimary := Account{
		ID:          34001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    0,
	}
	cachedSecondary := Account{
		ID:          34002,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    5,
	}

	// Repository view: the primary account is rate limited.
	repoPrimary := cachedPrimary
	repoPrimary.RateLimitResetAt = &limitedUntil
	repoSecondary := cachedSecondary

	snapshotStub := &openAISnapshotCacheStub{
		snapshotAccounts: []*Account{&cachedPrimary, &cachedSecondary},
		accountsByID:     map[int64]*Account{34001: &cachedPrimary, 34002: &cachedSecondary},
	}
	svc := &OpenAIGatewayService{
		accountRepo:       stubOpenAIAccountRepo{accounts: []Account{repoPrimary, repoSecondary}},
		cfg:               &config.Config{},
		schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotStub},
	}

	picked, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)

	require.NoError(t, err)
	require.NotNil(t, picked)
	require.Equal(t, int64(34002), picked.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(9) groupID := int64(9)
......
...@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil return nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
...@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ ...@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
if fresh == nil {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
...@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) { (requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if err == nil && result.Acquired { if account == nil {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return &AccountSelectionResult{ } else {
Account: account, result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
Acquired: true, if err == nil && result.Acquired {
ReleaseFunc: result.ReleaseFunc, _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
}, nil return &AccountSelectionResult{
} Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
AccountID: accountID, AccountID: accountID,
MaxConcurrency: account.Concurrency, MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout, Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting, MaxWaiting: cfg.StickySessionMaxWaiting,
}, },
}, nil }, nil
}
} }
} }
} }
...@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. ...@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh return fresh
} }
// recheckSelectedOpenAIAccountFromDB reloads an already-selected account from
// the account repository and checks that it is still eligible.
//
// It returns:
//   - nil when account is nil;
//   - account unchanged when either schedulerSnapshot or accountRepo is nil,
//     in which case no lookup is performed;
//   - nil when the lookup fails or finds nothing, when the reloaded account is
//     not schedulable or not an OpenAI account, or when requestedModel is
//     non-empty and unsupported by it;
//   - the reloaded account otherwise.
//
// syncOpenAICodexRateLimitFromExtra is applied to the reloaded account, with
// the current time, before the eligibility checks run.
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
	switch {
	case account == nil:
		return nil
	case s.schedulerSnapshot == nil || s.accountRepo == nil:
		return account
	}

	latest, err := s.accountRepo.GetByID(ctx, account.ID)
	if err != nil || latest == nil {
		return nil
	}
	syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())

	// The checks short-circuit in order: schedulable, then OpenAI, then model.
	usable := latest.IsSchedulable() && latest.IsOpenAI()
	if usable && requestedModel != "" {
		usable = latest.IsModelSupported(requestedModel)
	}
	if !usable {
		return nil
	}
	return latest
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
var ( var (
account *Account account *Account
......
...@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( ...@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.Zero(t, boundAccountID) require.Zero(t, boundAccountID)
} }
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss
// binds a previous response ID to account 13. The snapshot cache holds a copy
// with no rate limit, while the repository copy has a future
// RateLimitResetAt. The lookup must return no selection and must remove the
// response-to-account binding.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
	ctx := context.Background()
	groupID := int64(24)
	limitedUntil := time.Now().Add(30 * time.Minute)

	// Each account gets its own Extra map so the two copies never alias.
	wsV2Extra := func() map[string]any {
		return map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		}
	}
	cachedAccount := &Account{
		ID:          13,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra:       wsV2Extra(),
	}
	repoAccount := Account{
		ID:               13,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		RateLimitResetAt: &limitedUntil,
		Extra:            wsV2Extra(),
	}

	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	snapshotStub := &openAISnapshotCacheStub{
		accountsByID: map[int64]*Account{repoAccount.ID: cachedAccount},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{repoAccount}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
		schedulerSnapshot:  &SchedulerSnapshotService{cache: snapshotStub},
	}

	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", repoAccount.ID, time.Hour))

	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
	require.NoError(t, err)
	require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")

	// The binding must be gone after the miss.
	stillBound, lookupErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
	require.NoError(t, lookupErr)
	require.Zero(t, stillBound)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(23) groupID := int64(23)
......
...@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( ...@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil return nil, nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {
......
...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil { if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else { } else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
......
...@@ -15,9 +15,11 @@ import ( ...@@ -15,9 +15,11 @@ import (
type rateLimitAccountRepoStub struct { type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini mockAccountRepoForGemini
setErrorCalls int setErrorCalls int
tempCalls int tempCalls int
lastErrorMsg string updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
} }
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id ...@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil return nil
} }
// UpdateCredentials records the call for later assertions: it keeps a copy of
// the credentials it was given in lastCredentials and increments
// updateCredentialsCalls. It always returns nil.
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
	r.lastCredentials = cloneCredentials(credentials)
	r.updateCredentialsCalls++
	return nil
}
type tokenCacheInvalidatorRecorder struct { type tokenCacheInvalidatorRecorder struct {
accounts []*Account accounts []*Account
err error err error
...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ...@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
} }
...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { ...@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts) require.Empty(t, invalidator.accounts)
} }
// TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater
// sends a 401 for an OAuth account and checks that the handler reports the
// account should be disabled, calls UpdateCredentials on the repository
// exactly once, and persists a non-empty "expires_at" credential.
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	oauthAccount := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token",
		},
	}
	svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)

	shouldDisable := svc.HandleUpstreamError(context.Background(), oauthAccount, 401, http.Header{}, []byte("unauthorized"))

	require.True(t, shouldDisable)
	require.Equal(t, 1, repo.updateCredentialsCalls)
	require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
...@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun ...@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
} }
if c.accountRepo != nil { if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
} }
} }
......
...@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials, err = refresher.Refresh(ctx, account) newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr) return fmt.Errorf("failed to save credentials: %w", saveErr)
} }
} }
......
...@@ -14,19 +14,40 @@ import ( ...@@ -14,19 +14,40 @@ import (
type tokenRefreshAccountRepo struct { type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
updateCalls int updateCalls int
setErrorCalls int fullUpdateCalls int
clearTempCalls int updateCredentialsCalls int
lastAccount *Account setErrorCalls int
updateErr error clearTempCalls int
lastAccount *Account
updateErr error
} }
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++ r.updateCalls++
r.fullUpdateCalls++
r.lastAccount = account r.lastAccount = account
return r.updateErr return r.updateErr
} }
// UpdateCredentials is the test double for the credentials-only update. It
// increments updateCalls and updateCredentialsCalls, then returns updateErr
// if one is configured. Otherwise it copies the credentials and either
// attaches them to the matching non-nil entry in accountsByID, or creates a
// new Account holding just the ID and the copy. In both cases lastAccount
// points at the account that received the credentials.
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
	r.updateCredentialsCalls++
	r.updateCalls++

	if r.updateErr != nil {
		return r.updateErr
	}

	snapshot := cloneCredentials(credentials)

	// Indexing a nil map is safe and yields nil, which covers both the
	// "no map" and "no entry" cases.
	if known := r.accountsByID[id]; known != nil {
		known.Credentials = snapshot
		r.lastAccount = known
		return nil
	}

	r.lastAccount = &Account{ID: id, Credentials: snapshot}
	return nil
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++ r.setErrorCalls++
return nil return nil
...@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { ...@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.Equal(t, 1, invalidator.calls) require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token")) require.Equal(t, "new-token", account.GetCredential("access_token"))
} }
...@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { ...@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
} }
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
Credentials: map[string]any{
"access_token": "old-token",
},
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.NotNil(t, account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
...@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing ...@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment