Commit ace08206 authored by IanShaw027's avatar IanShaw027
Browse files

fix: honor ws transport when scheduler is disabled

parent 65efef1e
...@@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ...@@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
} }
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 if s == nil || s.service == nil {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || s.service == nil || account == nil {
return false return false
} }
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
} }
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
...@@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( ...@@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
decision := OpenAIAccountScheduleDecision{} decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx) scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil { if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance decision.Layer = openAIAccountScheduleLayerLoadBalance
return selection, decision, err if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
return selection, decision, err
}
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if effectiveExcludedIDs == nil {
effectiveExcludedIDs = make(map[int64]struct{})
}
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
return nil, decision, ErrNoAvailableAccounts
}
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
}
} }
var stickyAccountID int64 var stickyAccountID int64
...@@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( ...@@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}) })
} }
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
if len(excludedIDs) == 0 {
return nil
}
cloned := make(map[int64]struct{}, len(excludedIDs))
for id := range excludedIDs {
cloned[id] = struct{}{}
}
return cloned
}
func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || account == nil {
return false
}
return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler(context.Background()) scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil { if scheduler == nil {
......
...@@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega ...@@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
require.False(t, decision.StickyPreviousHit) require.False(t, decision.StickyPreviousHit)
} }
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10108)
accounts := []Account{
{
ID: 36011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
{
ID: 36012,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36012), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10109)
accounts := []Account{
{
ID: 36021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest() resetOpenAIAdvancedSchedulerSettingCacheForTest()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment