Commit fef9259a authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

fix(openai): recheck runtime state from db before final account selection

parent ad7c1072
......@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
......@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
......
......@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background()
groupID := int64(10103)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(33002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10104)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{stalePrimary, staleSecondary},
accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(34002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
......
......@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
......@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
if fresh == nil {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
......@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
......@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh
}
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
return account
}
latest, err := s.accountRepo.GetByID(ctx, account.ID)
if err != nil || latest == nil {
return nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
return nil
}
return latest
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
var (
account *Account
......
......@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(24)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleAccount := &Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
dbAccount := Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
RateLimitResetAt: &rateLimitedUntil,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
snapshotCache := &openAISnapshotCacheStub{
accountsByID: map[int64]*Account{dbAccount.ID: staleAccount},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
require.NoError(t, getErr)
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
ctx := context.Background()
groupID := int64(23)
......
......@@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment