Unverified commit 9d795061, authored by Wesley Liddick and committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey checks the outcome
// of applying the retry payload strategy at attempt 3: the reported strategy
// is "trim_optional_fields", "include" is listed as removed and is gone from
// the payload, and "prompt_cache_key" and "text" are still present.
func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
	const cacheKey = "pcache_123"

	textOptions := map[string]any{"verbosity": "low"}
	toolList := []any{map[string]any{"type": "function"}}
	reqPayload := map[string]any{
		"model":            "gpt-5.3-codex",
		"prompt_cache_key": cacheKey,
		"include":          []any{"reasoning.encrypted_content"},
		"text":             textOptions,
		"tools":            toolList,
	}

	gotStrategy, droppedFields := applyOpenAIWSRetryPayloadStrategy(reqPayload, 3)

	// What the function reports back.
	require.Equal(t, "trim_optional_fields", gotStrategy)
	require.Contains(t, droppedFields, "include")
	require.NotContains(t, droppedFields, "prompt_cache_key")

	// What happened to the payload map itself.
	require.Equal(t, cacheKey, reqPayload["prompt_cache_key"])
	require.NotContains(t, reqPayload, "include")
	require.Contains(t, reqPayload, "text")
}
// TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields checks
// that at attempt 6 the strategy is still "trim_optional_fields": "include" is
// reported as removed, while prompt_cache_key, instructions, tools,
// tool_choice, parallel_tool_calls and text all remain in the payload.
func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
	const cacheKey = "pcache_456"

	reqPayload := map[string]any{
		"prompt_cache_key":    cacheKey,
		"instructions":        "long instructions",
		"tools":               []any{map[string]any{"type": "function"}},
		"parallel_tool_calls": true,
		"tool_choice":         "auto",
		"include":             []any{"reasoning.encrypted_content"},
		"text":                map[string]any{"verbosity": "high"},
	}

	gotStrategy, droppedFields := applyOpenAIWSRetryPayloadStrategy(reqPayload, 6)

	require.Equal(t, "trim_optional_fields", gotStrategy)
	require.Contains(t, droppedFields, "include")
	require.NotContains(t, droppedFields, "prompt_cache_key")
	require.Equal(t, cacheKey, reqPayload["prompt_cache_key"])

	// Fields that must still be present after trimming.
	require.Contains(t, reqPayload, "instructions")
	require.Contains(t, reqPayload, "tools")
	require.Contains(t, reqPayload, "tool_choice")
	require.Contains(t, reqPayload, "parallel_tool_calls")
	require.Contains(t, reqPayload, "text")
}
package service
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky runs Forward
// against a local websocket server that answers one request with a
// response.created and a response.completed event. It then asserts on the
// returned usage and request ID, on the frame the server received, on the
// response-to-account and response-to-connection entries in the state store,
// and on the body written to the recorder.
func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		newResponseID = "resp_new_1"
		upstreamModel = "gpt-5.1"
	)

	// capturedFrame holds the fields extracted from the frame read by the server.
	type capturedFrame struct {
		Type               string
		PreviousResponseID string
		StreamExists       bool
		Stream             bool
	}
	frames := make(chan capturedFrame, 1)

	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	serveWS := func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()

		var incoming map[string]any
		if err := conn.ReadJSON(&incoming); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}
		raw := requestToJSONString(incoming)
		streamField := gjson.Get(raw, "stream")
		frames <- capturedFrame{
			Type:               strings.TrimSpace(gjson.Get(raw, "type").String()),
			PreviousResponseID: strings.TrimSpace(gjson.Get(raw, "previous_response_id").String()),
			StreamExists:       streamField.Exists(),
			Stream:             streamField.Bool(),
		}

		createdEvent := map[string]any{
			"type": "response.created",
			"response": map[string]any{
				"id":    newResponseID,
				"model": upstreamModel,
			},
		}
		if err := conn.WriteJSON(createdEvent); err != nil {
			t.Errorf("write response.created failed: %v", err)
			return
		}

		completedEvent := map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    newResponseID,
				"model": upstreamModel,
				"usage": map[string]any{
					"input_tokens":  12,
					"output_tokens": 7,
					"input_tokens_details": map[string]any{
						"cached_tokens": 3,
					},
				},
			},
		}
		if err := conn.WriteJSON(completedEvent); err != nil {
			t.Errorf("write response.completed failed: %v", err)
			return
		}
	}
	wsServer := httptest.NewServer(http.HandlerFunc(serveWS))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
	groupID := int64(1001)
	ginCtx.Set("api_key", &APIKey{GroupID: &groupID})

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.QueueLimitPerConn = 8
	wsCfg.DialTimeoutSeconds = 3
	wsCfg.ReadTimeoutSeconds = 30
	wsCfg.WriteTimeoutSeconds = 10
	wsCfg.StickyResponseIDTTLSeconds = 3600

	// HTTP upstream stub; the assertions below check that it never saw a request body.
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpStub,
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	account := &Account{
		ID:          9,
		Name:        "openai-ws",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result)

	// Usage and identifiers come from the response.completed event.
	require.Equal(t, 12, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
	require.Equal(t, newResponseID, result.RequestID)
	require.True(t, result.OpenAIWSMode)
	require.False(t, gjson.GetBytes(httpStub.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")

	// Frame observed by the websocket server.
	got := <-frames
	require.Equal(t, "response.create", got.Type)
	require.Equal(t, "resp_prev_1", got.PreviousResponseID)
	require.True(t, got.StreamExists, "WS 请求应携带 stream 字段")
	require.False(t, got.Stream, "应保持客户端 stream=false 的原始语义")

	// Bindings recorded in the state store for the new response ID.
	stateStore := svc.getOpenAIWSStateStore()
	boundAccountID, lookupErr := stateStore.GetResponseAccount(context.Background(), groupID, newResponseID)
	require.NoError(t, lookupErr)
	require.Equal(t, account.ID, boundAccountID)
	boundConnID, found := stateStore.GetResponseConn(newResponseID)
	require.True(t, found)
	require.NotEmpty(t, boundConnID)

	written := recorder.Body.Bytes()
	require.Equal(t, newResponseID, gjson.GetBytes(written, "id").String())
}
// requestToJSONString encodes payload as a JSON string. It returns "{}" when
// the map is nil or empty, and also when json.Marshal reports an error.
func requestToJSONString(payload map[string]any) string {
	const emptyObject = "{}"
	if len(payload) > 0 {
		if encoded, err := json.Marshal(payload); err == nil {
			return string(encoded)
		}
	}
	return emptyObject
}
func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
require.NotPanics(t, func() {
logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil)
})
require.NotPanics(t, func() {
logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed"))
})
}
// TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent
// feeds Forward a single response.completed event through a capture dialer.
// The account maps "custom-original-model" to "gpt-5.1"; the test asserts that
// the body written to the client carries the requested model name and that the
// top-level tool call name "apply_patch" is emitted as "edit".
func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const requestedModel = "custom-original-model"

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
	groupID := int64(3001)
	ginCtx.Set("api_key", &APIKey{GroupID: &groupID})

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1
	wsCfg.QueueLimitPerConn = 8
	wsCfg.DialTimeoutSeconds = 3
	wsCfg.ReadTimeoutSeconds = 5
	wsCfg.WriteTimeoutSeconds = 3

	// The only upstream event: a completed response carrying tool calls both
	// inside "response" and at the top level.
	completedEvent := []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`)
	fakeConn := &openAIWSCaptureConn{
		events: [][]byte{completedEvent},
	}
	fakeDialer := &openAIWSCaptureDialer{conn: fakeConn}
	connPool := newOpenAIWSConnPool(cfg)
	connPool.setClientDialerForTest(fakeDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     connPool,
	}

	account := &Account{
		ID:          1301,
		Name:        "openai-rewrite",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
			"model_mapping": map[string]any{
				requestedModel: "gpt-5.1",
			},
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_model_tool_1", result.RequestID)

	gotModel := gjson.GetBytes(recorder.Body.Bytes(), "model").String()
	require.Equal(t, requestedModel, gotModel, "响应模型应回写为原始请求模型")
	gotToolName := gjson.GetBytes(recorder.Body.Bytes(), "tool_calls.0.function.name").String()
	require.Equal(t, "edit", gotToolName, "工具名称应被修正为 OpenCode 规范")
}
func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
payload := map[string]any{
"type": nil,
"model": 123,
"prompt_cache_key": " cache-key ",
"previous_response_id": []byte(" resp_1 "),
}
require.Equal(t, "", openAIWSPayloadString(payload, "type"))
require.Equal(t, "", openAIWSPayloadString(payload, "model"))
require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key"))
require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id"))
}
// TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne sends two Forward
// calls for the same account to a local websocket server that keeps serving
// requests on one connection. It asserts that the server saw exactly one
// upgrade and that the pool metrics report at least one reuse and one pick.
func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		upgrades  atomic.Int64
		responses atomic.Int64
	)
	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	serveWS := func(w http.ResponseWriter, r *http.Request) {
		upgrades.Add(1)
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()

		// Answer every request on this connection until a read or write fails.
		for {
			var incoming map[string]any
			if err := conn.ReadJSON(&incoming); err != nil {
				return
			}
			seq := responses.Add(1)
			responseID := "resp_reuse_" + strconv.FormatInt(seq, 10)

			createdEvent := map[string]any{
				"type": "response.created",
				"response": map[string]any{
					"id":    responseID,
					"model": "gpt-5.1",
				},
			}
			if err := conn.WriteJSON(createdEvent); err != nil {
				return
			}

			completedEvent := map[string]any{
				"type": "response.completed",
				"response": map[string]any{
					"id":    responseID,
					"model": "gpt-5.1",
					"usage": map[string]any{
						"input_tokens":  2,
						"output_tokens": 1,
					},
				},
			}
			if err := conn.WriteJSON(completedEvent); err != nil {
				return
			}
		}
	}
	wsServer := httptest.NewServer(http.HandlerFunc(serveWS))
	defer wsServer.Close()

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1
	wsCfg.QueueLimitPerConn = 8
	wsCfg.DialTimeoutSeconds = 3
	wsCfg.ReadTimeoutSeconds = 30
	wsCfg.WriteTimeoutSeconds = 10

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	account := &Account{
		ID:          19,
		Name:        "openai-ws",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Two independent client requests, each with its own gin context and body.
	for round := 0; round < 2; round++ {
		recorder := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(recorder)
		ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
		ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
		groupID := int64(2001)
		ginCtx.Set("api_key", &APIKey{GroupID: &groupID})

		reqBody := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
		result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
	}

	require.Equal(t, int64(1), upgrades.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
	poolMetrics := svc.SnapshotOpenAIWSPoolMetrics()
	require.GreaterOrEqual(t, poolMetrics.AcquireReuseTotal, int64(1))
	require.GreaterOrEqual(t, poolMetrics.ConnPickTotal, int64(1))
}
// TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault forwards an
// OAuth-account request that sets "store": true with AllowStoreRecovery
// disabled. It asserts that the frame written upstream has store=false and
// stream=true, and that the dial headers carry the expected OpenAI-Beta value
// plus the client's session_id and conversation_id.
func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		sessionID      = "sess-oauth-1"
		conversationID = "conv-oauth-1"
	)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
	ginCtx.Request.Header.Set("session_id", sessionID)
	ginCtx.Request.Header.Set("conversation_id", conversationID)

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.AllowStoreRecovery = false
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1

	fakeConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
		},
	}
	fakeDialer := &openAIWSCaptureDialer{conn: fakeConn}
	connPool := newOpenAIWSConnPool(cfg)
	connPool.setClientDialerForTest(fakeDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     connPool,
	}

	account := &Account{
		ID:          29,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "oauth-token-1",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_oauth_1", result.RequestID)

	// Inspect the last frame written to the upstream connection.
	require.NotNil(t, fakeConn.lastWrite)
	sentJSON := requestToJSONString(fakeConn.lastWrite)
	require.True(t, gjson.Get(sentJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
	require.False(t, gjson.Get(sentJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
	require.True(t, gjson.Get(sentJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
	require.True(t, gjson.Get(sentJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")

	// Inspect the headers used for the dial.
	require.Equal(t, openAIWSBetaV2Value, fakeDialer.lastHeaders.Get("OpenAI-Beta"))
	require.Equal(t, sessionID, fakeDialer.lastHeaders.Get("session_id"))
	require.Equal(t, conversationID, fakeDialer.lastHeaders.Get("conversation_id"))
}
// TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey
// forwards a request that has no session_id or conversation_id header but does
// set prompt_cache_key in the body. It asserts that the dial headers carry the
// prompt_cache_key value as session_id, leave conversation_id empty, and that
// the written frame contains a stream field.
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1

	fakeConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}
	fakeDialer := &openAIWSCaptureDialer{conn: fakeConn}
	connPool := newOpenAIWSConnPool(cfg)
	connPool.setClientDialerForTest(fakeDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     connPool,
	}

	account := &Account{
		ID:          31,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "oauth-token-1",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_prompt_cache_key", result.RequestID)

	// Dial headers: session_id equals the body's prompt_cache_key.
	require.Equal(t, "pcache_123", fakeDialer.lastHeaders.Get("session_id"))
	require.Empty(t, fakeDialer.lastHeaders.Get("conversation_id"))

	require.NotNil(t, fakeConn.lastWrite)
	sentJSON := requestToJSONString(fakeConn.lastWrite)
	require.True(t, gjson.Get(sentJSON, "stream").Exists())
}
// TestOpenAIGatewayService_Forward_WSv1_Unsupported configures the gateway
// with ResponsesWebsockets enabled and ResponsesWebsocketsV2 disabled. It
// asserts that Forward returns an error mentioning "ws v1", writes a 400
// response mentioning "WSv1", and never sends a request to the HTTP upstream.
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsockets = true
	wsCfg.ResponsesWebsocketsV2 = false

	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpStub,
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	account := &Account{
		ID:          39,
		Name:        "openai-ws-v1",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": "https://api.openai.com/v1/responses",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)

	// The call must fail without producing a result.
	require.Error(t, err)
	require.Nil(t, result)
	require.Contains(t, err.Error(), "ws v1")

	// The client sees a 400, and the HTTP upstream stub recorded no request.
	require.Equal(t, http.StatusBadRequest, recorder.Code)
	require.Contains(t, recorder.Body.String(), "WSv1")
	require.Nil(t, httpStub.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
}
// TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect
// forwards two requests in the same session, evicting the pooled connection in
// between so the second one has to dial again. The server returns an
// x-codex-turn-state header on the first handshake only. The test asserts that
// the turn state is stored for the session, that each handshake carries its
// own x-codex-turn-metadata, and that the second handshake sends the stored
// turn state back.
func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		sessionID  = "session_turn_state"
		firstState = "turn_state_first"
	)

	var handshakes atomic.Int64
	handshakeHeaders := make(chan http.Header, 4)
	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	serveWS := func(w http.ResponseWriter, r *http.Request) {
		n := handshakes.Add(1)
		handshakeHeaders <- cloneHeader(r.Header)

		// Only the first handshake hands out a turn state.
		upgradeHeader := http.Header{}
		if n == 1 {
			upgradeHeader.Set("x-codex-turn-state", firstState)
		}
		conn, err := upgrader.Upgrade(w, r, upgradeHeader)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()

		var incoming map[string]any
		if err := conn.ReadJSON(&incoming); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}

		completedEvent := map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    "resp_turn_" + strconv.FormatInt(n, 10),
				"model": "gpt-5.1",
				"usage": map[string]any{
					"input_tokens":  2,
					"output_tokens": 1,
				},
			},
		}
		if err := conn.WriteJSON(completedEvent); err != nil {
			t.Errorf("write response.completed failed: %v", err)
			return
		}
	}
	wsServer := httptest.NewServer(http.HandlerFunc(serveWS))
	defer wsServer.Close()

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 0

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	account := &Account{
		ID:          49,
		Name:        "openai-turn-state",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// First turn.
	firstRecorder := httptest.NewRecorder()
	firstCtx, _ := gin.CreateTestContext(firstRecorder)
	firstCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	firstCtx.Request.Header.Set("session_id", sessionID)
	firstCtx.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
	firstResult, err := svc.Forward(context.Background(), firstCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, firstResult)

	sessionHash := svc.GenerateSessionHash(firstCtx, reqBody)
	stateStore := svc.getOpenAIWSStateStore()
	storedState, found := stateStore.GetSessionTurnState(0, sessionHash)
	require.True(t, found)
	require.Equal(t, firstState, storedState)

	// Evict the connection on purpose to simulate a reconnect on the next request.
	boundConnID, hasConn := stateStore.GetResponseConn(firstResult.RequestID)
	require.True(t, hasConn)
	svc.getOpenAIWSConnPool().evictConn(account.ID, boundConnID)

	// Second turn, same session, new metadata.
	secondRecorder := httptest.NewRecorder()
	secondCtx, _ := gin.CreateTestContext(secondRecorder)
	secondCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	secondCtx.Request.Header.Set("session_id", sessionID)
	secondCtx.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
	secondResult, err := svc.Forward(context.Background(), secondCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, secondResult)

	firstHandshake := <-handshakeHeaders
	secondHandshake := <-handshakeHeaders
	require.Equal(t, "turn_meta_1", firstHandshake.Get("X-Codex-Turn-Metadata"))
	require.Equal(t, "turn_meta_2", secondHandshake.Get("X-Codex-Turn-Metadata"))
	require.Equal(t, firstState, secondHandshake.Get("X-Codex-Turn-State"))
}
// TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm enables
// PrewarmGenerateEnabled and supplies two response.completed events through a
// capture connection. It asserts that Forward reports the second event's ID,
// that exactly two frames were written, that the first frame carries
// generate=false, and that the second frame has no generate field.
func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("session_id", "session-prewarm")

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.PrewarmGenerateEnabled = true
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1

	// Events are listed in the order they were queued: prewarm reply, then main reply.
	prewarmReply := []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`)
	mainReply := []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`)
	fakeConn := &openAIWSCaptureConn{
		events: [][]byte{prewarmReply, mainReply},
	}
	fakeDialer := &openAIWSCaptureDialer{conn: fakeConn}
	connPool := newOpenAIWSConnPool(cfg)
	connPool.setClientDialerForTest(fakeDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     connPool,
	}

	account := &Account{
		ID:          59,
		Name:        "openai-prewarm",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), ginCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_main_1", result.RequestID)

	require.Len(t, fakeConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
	prewarmFrame := requestToJSONString(fakeConn.writes[0])
	mainFrame := requestToJSONString(fakeConn.writes[1])
	require.True(t, gjson.Get(prewarmFrame, "generate").Exists())
	require.False(t, gjson.Get(prewarmFrame, "generate").Bool())
	require.False(t, gjson.Get(mainFrame, "generate").Exists())
}
// TestOpenAIGatewayService_PrewarmReadHonorsParentContext calls
// performOpenAIWSGeneratePrewarm with a 40ms context on a connection whose
// read is configured with a 200ms delay. It asserts that the call fails with
// an error mentioning "prewarm_read_event" and returns in under 180ms.
func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.PrewarmGenerateEnabled = true
	wsCfg.ReadTimeoutSeconds = 5
	wsCfg.WriteTimeoutSeconds = 3

	svc := &OpenAIGatewayService{
		cfg:           cfg,
		toolCorrector: NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          601,
		Name:        "openai-prewarm-timeout",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
	}

	slowConn := &openAIWSBlockingConn{
		readDelay: 200 * time.Millisecond,
	}
	wsConn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, slowConn, nil)
	lease := &openAIWSConnLease{
		accountID: account.ID,
		conn:      wsConn,
	}
	framePayload := map[string]any{
		"type":  "response.create",
		"model": "gpt-5.1",
	}
	decision := OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}

	// The parent context expires well before the connection's read delay.
	ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
	defer cancel()

	begin := time.Now()
	err := svc.performOpenAIWSGeneratePrewarm(
		ctx,
		lease,
		decision,
		framePayload,
		"",
		map[string]any{"model": "gpt-5.1"},
		account,
		nil,
		0,
	)
	took := time.Since(begin)

	require.Error(t, err)
	require.Contains(t, err.Error(), "prewarm_read_event")
	require.Less(t, took, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
}
// TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse
// forwards two requests in the same session through a capture dialer, each
// with a different x-codex-turn-metadata header. It asserts that only one dial
// happened, that two frames were written, and that each frame carries its own
// header value under client_metadata.x-codex-turn-metadata.
func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		sessionID   = "session-metadata-reuse"
		firstMeta   = "turn_meta_payload_1"
		secondMeta  = "turn_meta_payload_2"
		metaJSONKey = "client_metadata.x-codex-turn-metadata"
	)

	cfg := &config.Config{}
	allowlist := &cfg.Security.URLAllowlist
	allowlist.Enabled = false
	allowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 1

	fakeConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	fakeDialer := &openAIWSCaptureDialer{conn: fakeConn}
	connPool := newOpenAIWSConnPool(cfg)
	connPool.setClientDialerForTest(fakeDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     connPool,
	}

	account := &Account{
		ID:          69,
		Name:        "openai-turn-metadata",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// First turn.
	firstRecorder := httptest.NewRecorder()
	firstCtx, _ := gin.CreateTestContext(firstRecorder)
	firstCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	firstCtx.Request.Header.Set("session_id", sessionID)
	firstCtx.Request.Header.Set("x-codex-turn-metadata", firstMeta)
	firstResult, err := svc.Forward(context.Background(), firstCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, firstResult)
	require.Equal(t, "resp_meta_1", firstResult.RequestID)

	// Second turn, same session, different metadata header.
	secondRecorder := httptest.NewRecorder()
	secondCtx, _ := gin.CreateTestContext(secondRecorder)
	secondCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	secondCtx.Request.Header.Set("session_id", sessionID)
	secondCtx.Request.Header.Set("x-codex-turn-metadata", secondMeta)
	secondResult, err := svc.Forward(context.Background(), secondCtx, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, secondResult)
	require.Equal(t, "resp_meta_2", secondResult.RequestID)

	require.Equal(t, 1, fakeDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
	require.Len(t, fakeConn.writes, 2)

	firstFrame := requestToJSONString(fakeConn.writes[0])
	secondFrame := requestToJSONString(fakeConn.writes[1])
	require.Equal(t, firstMeta, gjson.Get(firstFrame, metaJSONKey).String())
	require.Equal(t, secondMeta, gjson.Get(secondFrame, metaJSONKey).String())
}
// TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation verifies that,
// with StoreDisabledForceNewConn enabled, store=false requests carrying the same
// session_id reuse one upstream WebSocket (a single upgrade), while a request with
// a different session_id causes a second upgrade.
func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		upgrades  atomic.Int64
		responses atomic.Int64
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}

	// Fake upstream: counts upgrades and answers every request frame with a
	// response.completed event carrying a sequential response id.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upgrades.Add(1)
		serverConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = serverConn.Close()
		}()
		for {
			var incoming map[string]any
			if readErr := serverConn.ReadJSON(&incoming); readErr != nil {
				return
			}
			reply := map[string]any{
				"type": "response.completed",
				"response": map[string]any{
					"id":    "resp_store_false_" + strconv.FormatInt(responses.Add(1), 10),
					"model": "gpt-5.1",
					"usage": map[string]any{
						"input_tokens":  1,
						"output_tokens": 1,
					},
				},
			}
			if writeErr := serverConn.WriteJSON(reply); writeErr != nil {
				return
			}
		}
	}))
	defer upstream.Close()

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          79,
		Name:        "openai-store-false",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upstream.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// forward sends one request for the given session and asserts both success
	// and the cumulative number of upstream upgrades seen so far.
	forward := func(sessionID string, wantUpgrades int64, msgAndArgs ...any) {
		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
		ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
		ginCtx.Request.Header.Set("session_id", sessionID)
		result, err := svc.Forward(context.Background(), ginCtx, account, body)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.Equal(t, wantUpgrades, upgrades.Load(), msgAndArgs...)
	}

	forward("session_store_false_a", 1)
	forward("session_store_false_a", 1, "同一 session(store=false) 应复用同一 WS 连接")
	forward("session_store_false_b", 2, "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
}
// TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse verifies
// that with StoreDisabledForceNewConn turned off (and a per-account cap of one
// connection), store=false requests from two different sessions share a single
// upstream WebSocket upgrade.
func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		upgrades  atomic.Int64
		responses atomic.Int64
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}

	// Fake upstream: counts upgrades and answers every request frame with a
	// response.completed event carrying a sequential response id.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upgrades.Add(1)
		serverConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = serverConn.Close()
		}()
		for {
			var incoming map[string]any
			if readErr := serverConn.ReadJSON(&incoming); readErr != nil {
				return
			}
			reply := map[string]any{
				"type": "response.completed",
				"response": map[string]any{
					"id":    "resp_store_false_reuse_" + strconv.FormatInt(responses.Add(1), 10),
					"model": "gpt-5.1",
					"usage": map[string]any{
						"input_tokens":  1,
						"output_tokens": 1,
					},
				},
			}
			if writeErr := serverConn.WriteJSON(reply); writeErr != nil {
				return
			}
		}
	}))
	defer upstream.Close()

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          80,
		Name:        "openai-store-false-reuse",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upstream.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// forward sends one request for the given session and asserts both success
	// and the cumulative number of upstream upgrades seen so far.
	forward := func(sessionID string, wantUpgrades int64, msgAndArgs ...any) {
		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
		ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
		ginCtx.Request.Header.Set("session_id", sessionID)
		result, err := svc.Forward(context.Background(), ginCtx, account, body)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.Equal(t, wantUpgrades, upgrades.Load(), msgAndArgs...)
	}

	forward("session_store_false_reuse_a", 1)
	forward("session_store_false_reuse_b", 1, "关闭强制新连后,不同 session(store=false) 可复用连接")
}
// TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead verifies that the
// configured read timeout (1s) is applied to each read individually: two reads
// that each take 700ms exceed the timeout in total, yet the request must still
// complete over WebSocket without falling back to the HTTP upstream.
func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// Each scripted event is delivered after a 700ms delay: below the 1s
	// per-read timeout individually, above it when summed.
	captureConn := &openAIWSCaptureConn{
		readDelays: []time.Duration{
			700 * time.Millisecond,
			700 * time.Millisecond,
		},
		events: [][]byte{
			[]byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}
	captureDialer := &openAIWSCaptureDialer{conn: captureConn}
	pool := newOpenAIWSConnPool(cfg)
	// The pool constructor starts background worker goroutines; stop them when
	// the test ends so they do not leak into other tests.
	t.Cleanup(pool.Close)
	pool.setClientDialerForTest(captureDialer)

	// HTTP fallback recorder: it must stay untouched if the WS path succeeds.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     upstream,
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}
	account := &Account{
		ID:          81,
		Name:        "openai-read-timeout",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)

	result, err := svc.Forward(context.Background(), c, account, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_timeout_ok", result.RequestID)
	require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
}
// openAIWSCaptureDialer is a test dialer that always hands out one pre-built
// connection and records how it was dialed.
type openAIWSCaptureDialer struct {
// mu is held by Dial and DialCount while touching the recorded state below.
mu sync.Mutex
// conn is the connection returned by every Dial call.
conn *openAIWSCaptureConn
// lastHeaders is a clone of the request headers from the most recent Dial.
lastHeaders http.Header
// handshake is cloned and returned as the response headers of each Dial.
handshake http.Header
// dialCount is the number of Dial calls made so far.
dialCount int
}
// Dial records the request headers, bumps the dial counter and returns the
// dialer's fixed connection together with a copy of the configured handshake
// headers. The context, URL and proxy arguments are ignored; the status code is
// always 0 and the error always nil.
func (d *openAIWSCaptureDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_, _, _ = ctx, wsURL, proxyURL

	// Update the recorded state and snapshot the handshake under the lock.
	handshakeCopy := func() http.Header {
		d.mu.Lock()
		defer d.mu.Unlock()
		d.lastHeaders = cloneHeader(headers)
		d.dialCount++
		return cloneHeader(d.handshake)
	}()

	return d.conn, 0, handshakeCopy, nil
}
// DialCount reports how many times Dial has been called so far.
func (d *openAIWSCaptureDialer) DialCount() int {
	d.mu.Lock()
	count := d.dialCount
	d.mu.Unlock()
	return count
}
// openAIWSCaptureConn is a scripted test connection: reads replay a fixed list
// of events (optionally delayed) and writes are decoded and recorded.
type openAIWSCaptureConn struct {
// mu is held by WriteJSON, ReadMessage and Close while touching the fields below.
mu sync.Mutex
// readDelays holds per-read delays; ReadMessage consumes one entry per call while any remain.
readDelays []time.Duration
// events are the payloads ReadMessage returns, in order; once empty, reads yield io.EOF.
events [][]byte
// lastWrite is a shallow copy of the most recently recorded write payload.
lastWrite map[string]any
// writes holds shallow copies of every recorded write payload, in order.
writes []map[string]any
// closed is set by Close; afterwards reads and writes fail with errOpenAIWSConnClosed.
closed bool
}
// WriteJSON records the written payload instead of sending it anywhere.
// Payloads of type map[string]any are recorded directly; json.RawMessage and
// []byte payloads are recorded only when they decode into a JSON object (or
// null) without error. Any other payload type is silently ignored. It returns
// errOpenAIWSConnClosed once the connection has been closed, nil otherwise.
func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
	_ = ctx
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return errOpenAIWSConnClosed
	}

	// Normalise every supported payload shape to a decoded map first.
	var (
		decoded    map[string]any
		recordable bool
	)
	switch typed := value.(type) {
	case map[string]any:
		decoded, recordable = typed, true
	case json.RawMessage:
		recordable = json.Unmarshal(typed, &decoded) == nil
	case []byte:
		recordable = json.Unmarshal(typed, &decoded) == nil
	}

	if recordable {
		// Store independent shallow copies so later mutation of one does not
		// affect the other.
		c.lastWrite = cloneMapStringAny(decoded)
		c.writes = append(c.writes, cloneMapStringAny(decoded))
	}
	return nil
}
// ReadMessage returns the next scripted event. It fails with
// errOpenAIWSConnClosed after Close and with io.EOF once all events have been
// consumed. When a read delay is queued, the event is dequeued first and then
// delivered after the delay; if ctx finishes during the delay, ctx.Err() is
// returned and that event is not delivered.
func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	// Dequeue the next event and its delay under the lock.
	payload, pause, err := func() ([]byte, time.Duration, error) {
		c.mu.Lock()
		defer c.mu.Unlock()
		switch {
		case c.closed:
			return nil, 0, errOpenAIWSConnClosed
		case len(c.events) == 0:
			return nil, 0, io.EOF
		}
		var wait time.Duration
		if len(c.readDelays) > 0 {
			wait, c.readDelays = c.readDelays[0], c.readDelays[1:]
		}
		next := c.events[0]
		c.events = c.events[1:]
		return next, wait, nil
	}()
	if err != nil {
		return nil, err
	}
	if pause <= 0 {
		return payload, nil
	}

	// Simulate a slow upstream while still honouring context cancellation.
	timer := time.NewTimer(pause)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return payload, nil
	}
}
// Ping is a no-op for the capture connection and always succeeds; the context
// is ignored.
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
	return nil
}
// Close marks the capture connection as closed. It is safe to call repeatedly
// and always returns nil.
func (c *openAIWSCaptureConn) Close() error {
	c.mu.Lock()
	c.closed = true
	c.mu.Unlock()
	return nil
}
// cloneMapStringAny returns a shallow copy of src: keys are copied into a new
// map, values are shared with the original. A nil input yields nil.
func cloneMapStringAny(src map[string]any) map[string]any {
	if src == nil {
		return nil
	}
	copied := make(map[string]any, len(src))
	for key := range src {
		copied[key] = src[key]
	}
	return copied
}
package service
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"golang.org/x/sync/errgroup"
)
// Tuning constants for the OpenAI WebSocket connection pool.
const (
// NOTE(review): presumably the maximum lifetime of a pooled connection — not referenced in the visible code; confirm at its use site.
openAIWSConnMaxAge = 60 * time.Minute
// NOTE(review): presumably the idle time after which a connection is health-checked before reuse — confirm in shouldHealthCheckConn.
openAIWSConnHealthCheckIdle = 90 * time.Second
// Timeout applied to health-check pings (also the default when pingWithTimeout is given a non-positive timeout).
openAIWSConnHealthCheckTO = 2 * time.Second
// NOTE(review): prewarm-related delay; not referenced in the visible code — confirm at its use site.
openAIWSConnPrewarmExtraDelay = 2 * time.Second
// Minimum interval between per-account cleanups triggered from the acquire path.
openAIWSAcquireCleanupInterval = 3 * time.Second
// Tick interval of the background ping worker.
openAIWSBackgroundPingInterval = 30 * time.Second
// Tick interval of the background cleanup worker.
openAIWSBackgroundSweepTicker = 30 * time.Second
// NOTE(review): prewarm failure window/threshold; not referenced in the visible code — confirm at their use sites.
openAIWSPrewarmFailureWindow = 30 * time.Second
openAIWSPrewarmFailureSuppress = 2
)
// Sentinel errors returned by the connection pool, leases and connections.
var (
// errOpenAIWSConnClosed is returned when a connection or lease is nil, closed or already released.
errOpenAIWSConnClosed = errors.New("openai ws connection closed")
// errOpenAIWSConnQueueFull is returned when no connection can be created and the wait queue limit is reached.
errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
// errOpenAIWSPreferredConnUnavailable is returned when ForcePreferredConn is set but the preferred connection is missing.
errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
)
// openAIWSDialError describes a failed WebSocket dial, optionally carrying the
// HTTP status code and response headers of the failed handshake.
type openAIWSDialError struct {
	// StatusCode is the handshake HTTP status; values <= 0 are omitted from the message.
	StatusCode int
	// ResponseHeaders are the headers associated with the failed handshake, if any.
	ResponseHeaders http.Header
	// Err is the underlying cause, exposed through Unwrap.
	Err error
}

// Error formats the dial failure, including the status code when it is positive.
// A nil receiver yields the empty string.
func (e *openAIWSDialError) Error() string {
	switch {
	case e == nil:
		return ""
	case e.StatusCode > 0:
		return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
	default:
		return fmt.Sprintf("openai ws dial failed: %v", e.Err)
	}
}

// Unwrap returns the underlying cause so errors.Is / errors.As can see through
// the wrapper. A nil receiver yields nil.
func (e *openAIWSDialError) Unwrap() error {
	if e != nil {
		return e.Err
	}
	return nil
}
// openAIWSAcquireRequest describes one request to lease a pooled connection.
type openAIWSAcquireRequest struct {
// Account owns the connection; its ID selects the per-account pool and must be positive.
Account *Account
// WSURL is the upstream WebSocket URL; it must be non-empty after trimming.
WSURL string
// Headers and ProxyURL travel with the request to the dial step (presumably used by dialConn — confirm there).
Headers http.Header
ProxyURL string
// PreferredConnID names a connection to try first when reuse is allowed.
PreferredConnID string
// ForceNewConn: force this acquisition to obtain a new connection (avoids reuse polluting per-connection continuation state).
ForceNewConn bool
// ForcePreferredConn: force this acquisition to use only PreferredConnID; drifting to another connection is not allowed.
ForcePreferredConn bool
}
// openAIWSConnLease is an exclusive, releasable handle on one pooled connection.
type openAIWSConnLease struct {
// pool and accountID identify where the connection lives; used by MarkBroken to evict it.
pool *openAIWSConnPool
accountID int64
// conn is the leased connection.
conn *openAIWSConn
// queueWait is the time spent waiting in a connection's queue (zero when no wait occurred).
queueWait time.Duration
// connPick is the time spent selecting the connection.
connPick time.Duration
// reused is true when an existing connection was leased rather than a freshly dialed one.
reused bool
// released flips to true on the first Release; afterwards I/O through the lease fails.
released atomic.Bool
}
// activeConn returns the leased connection, or errOpenAIWSConnClosed when the
// lease is nil, has no connection, or has already been released.
func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
	usable := l != nil && l.conn != nil && !l.released.Load()
	if !usable {
		return nil, errOpenAIWSConnClosed
	}
	return l.conn, nil
}
// ConnID returns the ID of the leased connection, or "" for a nil/empty lease.
func (l *openAIWSConnLease) ConnID() string {
	if l != nil && l.conn != nil {
		return l.conn.id
	}
	return ""
}
// QueueWaitDuration returns the recorded queue wait time of this lease, or 0
// for a nil lease.
func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
	var waited time.Duration
	if l != nil {
		waited = l.queueWait
	}
	return waited
}
// ConnPickDuration returns the recorded connection-selection time of this
// lease, or 0 for a nil lease.
func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
	var picked time.Duration
	if l != nil {
		picked = l.connPick
	}
	return picked
}
// Reused reports whether the lease was served by an existing connection.
// A nil lease reports false.
func (l *openAIWSConnLease) Reused() bool {
	return l != nil && l.reused
}
// HandshakeHeader returns the named handshake response header of the leased
// connection, or "" for a nil/empty lease.
func (l *openAIWSConnLease) HandshakeHeader(name string) string {
	if l != nil && l.conn != nil {
		return l.conn.handshakeHeader(name)
	}
	return ""
}
// IsPrewarmed reports whether the leased connection has been marked as
// prewarmed. A nil/empty lease reports false.
func (l *openAIWSConnLease) IsPrewarmed() bool {
	if l != nil && l.conn != nil {
		return l.conn.isPrewarmed()
	}
	return false
}
// MarkPrewarmed flags the leased connection as prewarmed. It is a no-op for a
// nil/empty lease.
func (l *openAIWSConnLease) MarkPrewarmed() {
	if l != nil && l.conn != nil {
		l.conn.markPrewarmed()
	}
}
// WriteJSON writes value on the leased connection using a background context
// bounded by timeout (no bound when timeout <= 0). It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return leaseErr
	}
	background := context.Background()
	return active.writeJSONWithTimeout(background, value, timeout)
}
// WriteJSONWithContextTimeout writes value on the leased connection using ctx,
// additionally bounded by timeout when timeout > 0. It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return leaseErr
	}
	return active.writeJSONWithTimeout(ctx, value, timeout)
}
// WriteJSONContext writes value on the leased connection using ctx as-is (no
// extra timeout). It fails with errOpenAIWSConnClosed when the lease is no
// longer active.
func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return leaseErr
	}
	return active.writeJSON(value, ctx)
}
// ReadMessage reads one message from the leased connection, bounded by timeout
// when timeout > 0. It fails with errOpenAIWSConnClosed when the lease is no
// longer active.
func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return nil, leaseErr
	}
	return active.readMessageWithTimeout(timeout)
}
// ReadMessageContext reads one message from the leased connection using ctx
// as-is (no extra timeout). It fails with errOpenAIWSConnClosed when the lease
// is no longer active.
func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return nil, leaseErr
	}
	return active.readMessage(ctx)
}
// ReadMessageWithContextTimeout reads one message from the leased connection
// using ctx, additionally bounded by timeout when timeout > 0. It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return nil, leaseErr
	}
	return active.readMessageWithContextTimeout(ctx, timeout)
}
// PingWithTimeout pings the leased connection with the given timeout. It fails
// with errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr != nil {
		return leaseErr
	}
	return active.pingWithTimeout(timeout)
}
// MarkBroken asks the pool to evict the leased connection. It does nothing when
// the lease is nil, incomplete, or has already been released.
func (l *openAIWSConnLease) MarkBroken() {
	if l == nil || l.pool == nil || l.conn == nil {
		return
	}
	if l.released.Load() {
		return
	}
	l.pool.evictConn(l.accountID, l.conn.id)
}
// Release hands the connection back to its pool slot. Only the first call has
// an effect; later calls (and calls on a nil/empty lease) are no-ops.
func (l *openAIWSConnLease) Release() {
	if l == nil || l.conn == nil {
		return
	}
	// The compare-and-swap guarantees the lease token is returned exactly once.
	if l.released.CompareAndSwap(false, true) {
		l.conn.release()
	}
}
// openAIWSConn wraps one upstream WebSocket connection with single-holder
// leasing, serialized reads/writes and usage timestamps.
type openAIWSConn struct {
// id identifies the connection within its account pool.
id string
// ws is the underlying client connection.
ws openAIWSClientConn
// handshakeHeaders is a clone of the handshake response headers.
handshakeHeaders http.Header
// leaseCh is a capacity-1 token channel: a token present means the connection is free.
leaseCh chan struct{}
// closedCh is closed exactly once (via closeOnce) when the connection is closed.
closedCh chan struct{}
closeOnce sync.Once
// readMu serializes readMessage; writeMu serializes writeJSON and pings.
readMu sync.Mutex
writeMu sync.Mutex
// waiters counts acquirers currently queued on this connection.
waiters atomic.Int32
// createdAtNano / lastUsedNano are Unix-nanosecond timestamps of creation and last use.
createdAtNano atomic.Int64
lastUsedNano atomic.Int64
// prewarmed records whether markPrewarmed has been called.
prewarmed atomic.Bool
}
// newOpenAIWSConn wraps ws in a pooled connection that starts out free (its
// lease token is available) with creation and last-used times set to now. The
// handshake headers are cloned; the second parameter is unused.
func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
	stamp := time.Now().UnixNano()

	// Seed the single lease token so the first acquirer succeeds immediately.
	leaseToken := make(chan struct{}, 1)
	leaseToken <- struct{}{}

	wrapped := &openAIWSConn{
		id:               id,
		ws:               ws,
		handshakeHeaders: cloneHeader(handshakeHeaders),
		leaseCh:          leaseToken,
		closedCh:         make(chan struct{}),
	}
	wrapped.createdAtNano.Store(stamp)
	wrapped.lastUsedNano.Store(stamp)
	return wrapped
}
// tryAcquire attempts to take the connection's lease without blocking. It
// returns false when the connection is nil, closed, or currently leased. If the
// connection turns out to be closed right after the token was taken, the token
// is given back and false is returned.
func (c *openAIWSConn) tryAcquire() bool {
	if c == nil {
		return false
	}
	isClosed := func() bool {
		select {
		case <-c.closedCh:
			return true
		default:
			return false
		}
	}

	if isClosed() {
		return false
	}
	select {
	case <-c.leaseCh:
	default:
		// Someone else holds the lease.
		return false
	}
	// Re-check: close() may have raced with taking the token.
	if isClosed() {
		c.release()
		return false
	}
	return true
}
// acquire blocks until the connection's lease token is obtained, the
// connection is closed, or ctx is done.
//
// It returns nil on success, ctx.Err() when ctx finishes first, and
// errOpenAIWSConnClosed when the receiver is nil or the connection is (or
// becomes) closed. A nil ctx is treated as context.Background(), matching the
// nil-context handling of writeJSON and readMessage.
func (c *openAIWSConn) acquire(ctx context.Context) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	case <-c.leaseCh:
	}

	// The token was taken; make sure the connection was not closed meanwhile
	// (close() also pushes a token to wake waiters).
	select {
	case <-c.closedCh:
		c.release()
		return errOpenAIWSConnClosed
	default:
		return nil
	}
}
// release returns the lease token (if the slot is empty) and refreshes the
// last-used timestamp. It is a no-op on a nil receiver.
func (c *openAIWSConn) release() {
	if c == nil {
		return
	}
	// Runs after the token has been handed back below.
	defer c.touch()
	select {
	case c.leaseCh <- struct{}{}:
	default:
		// Token already present; nothing to return.
	}
}
// close shuts the connection down exactly once: it closes closedCh, closes the
// underlying socket (ignoring its error), and pushes a lease token if the slot
// is empty so a blocked acquirer can wake up. No-op on a nil receiver.
func (c *openAIWSConn) close() {
	if c == nil {
		return
	}
	shutdown := func() {
		close(c.closedCh)
		if c.ws != nil {
			_ = c.ws.Close()
		}
		select {
		case c.leaseCh <- struct{}{}:
		default:
		}
	}
	c.closeOnce.Do(shutdown)
}
// writeJSONWithTimeout writes value using parent (or a background context when
// parent is nil), bounded by timeout when timeout > 0. It returns
// errOpenAIWSConnClosed for a nil receiver or an already closed connection.
func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}

	ctx := parent
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout > 0 {
		var stop context.CancelFunc
		ctx, stop = context.WithTimeout(ctx, timeout)
		defer stop()
	}
	return c.writeJSON(value, ctx)
}
// writeJSON writes value to the underlying connection while holding the write
// lock, then refreshes the last-used timestamp on success.
//
// It returns errOpenAIWSConnClosed when the receiver or its underlying
// connection is nil, and otherwise whatever the underlying WriteJSON returns.
// A nil writeCtx is replaced with context.Background().
func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
	// Guard the receiver like every other openAIWSConn method does; without it
	// a nil receiver would panic on the mutex below.
	if c == nil {
		return errOpenAIWSConnClosed
	}
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if c.ws == nil {
		return errOpenAIWSConnClosed
	}
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	if err := c.ws.WriteJSON(writeCtx, value); err != nil {
		return err
	}
	c.touch()
	return nil
}
// readMessageWithTimeout reads one message using a background context bounded
// by timeout (no bound when timeout <= 0).
func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
	background := context.Background()
	return c.readMessageWithContextTimeout(background, timeout)
}
// readMessageWithContextTimeout reads one message using parent (or a background
// context when parent is nil), bounded by timeout when timeout > 0. It returns
// errOpenAIWSConnClosed for a nil receiver or an already closed connection.
func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return nil, errOpenAIWSConnClosed
	default:
	}

	ctx := parent
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout > 0 {
		var stop context.CancelFunc
		ctx, stop = context.WithTimeout(ctx, timeout)
		defer stop()
	}
	return c.readMessage(ctx)
}
// readMessage reads one message from the underlying connection while holding
// the read lock, then refreshes the last-used timestamp on success.
//
// It returns errOpenAIWSConnClosed when the receiver or its underlying
// connection is nil, and otherwise whatever the underlying ReadMessage returns.
// A nil readCtx is replaced with context.Background().
func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
	// Guard the receiver like every other openAIWSConn method does; without it
	// a nil receiver would panic on the mutex below.
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	c.readMu.Lock()
	defer c.readMu.Unlock()
	if c.ws == nil {
		return nil, errOpenAIWSConnClosed
	}
	if readCtx == nil {
		readCtx = context.Background()
	}
	payload, err := c.ws.ReadMessage(readCtx)
	if err != nil {
		return nil, err
	}
	c.touch()
	return payload, nil
}
// pingWithTimeout pings the underlying connection while holding the write lock.
// A non-positive timeout falls back to openAIWSConnHealthCheckTO. It returns
// errOpenAIWSConnClosed for a nil receiver, a closed connection or a missing
// underlying socket; otherwise it returns the ping result.
func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if c.ws == nil {
		return errOpenAIWSConnClosed
	}

	limit := timeout
	if limit <= 0 {
		limit = openAIWSConnHealthCheckTO
	}
	// The deadline starts only after the write lock has been obtained.
	pingCtx, stop := context.WithTimeout(context.Background(), limit)
	defer stop()
	return c.ws.Ping(pingCtx)
}
// touch sets the last-used timestamp to now. No-op on a nil receiver.
func (c *openAIWSConn) touch() {
	if c != nil {
		c.lastUsedNano.Store(time.Now().UnixNano())
	}
}
// createdAt returns the creation time, or the zero time when the receiver is
// nil or no positive timestamp has been stored.
func (c *openAIWSConn) createdAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if ns := c.createdAtNano.Load(); ns > 0 {
		return time.Unix(0, ns)
	}
	return time.Time{}
}
// lastUsedAt returns the time of last use, or the zero time when the receiver
// is nil or no positive timestamp has been stored.
func (c *openAIWSConn) lastUsedAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if ns := c.lastUsedNano.Load(); ns > 0 {
		return time.Unix(0, ns)
	}
	return time.Time{}
}
// idleDuration returns how long before now the connection was last used. It
// yields 0 for a nil receiver or when no last-used time is recorded.
func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if lastUse := c.lastUsedAt(); !lastUse.IsZero() {
		return now.Sub(lastUse)
	}
	return 0
}
// age returns how long before now the connection was created. It yields 0 for a
// nil receiver or when no creation time is recorded.
func (c *openAIWSConn) age(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if born := c.createdAt(); !born.IsZero() {
		return now.Sub(born)
	}
	return 0
}
// isLeased reports whether the lease token is currently taken (the token
// channel is empty). A nil receiver reports false.
func (c *openAIWSConn) isLeased() bool {
	return c != nil && len(c.leaseCh) == 0
}
// handshakeHeader returns the whitespace-trimmed value of the named handshake
// header (the name is trimmed too), or "" when the receiver or its headers are
// nil.
func (c *openAIWSConn) handshakeHeader(name string) string {
	if c == nil || c.handshakeHeaders == nil {
		return ""
	}
	key := strings.TrimSpace(name)
	raw := c.handshakeHeaders.Get(key)
	return strings.TrimSpace(raw)
}
// isPrewarmed reports whether markPrewarmed has been called on this
// connection. A nil receiver reports false.
func (c *openAIWSConn) isPrewarmed() bool {
	return c != nil && c.prewarmed.Load()
}
// markPrewarmed flags the connection as prewarmed. No-op on a nil receiver.
func (c *openAIWSConn) markPrewarmed() {
	if c != nil {
		c.prewarmed.Store(true)
	}
}
// openAIWSAccountPool holds the pooled connections and bookkeeping for a single account.
type openAIWSAccountPool struct {
// mu is held by the pool while reading or modifying the fields below.
mu sync.Mutex
// conns maps connection ID to connection.
conns map[string]*openAIWSConn
// pinnedConns maps connection ID to a pin count; a count > 0 marks the connection as pinned.
pinnedConns map[string]int
// creating is the number of dials currently in flight for this account.
creating int
// lastCleanupAt is when cleanup last ran for this account.
lastCleanupAt time.Time
// lastAcquire is a clone of the most recent acquire request for this account.
lastAcquire *openAIWSAcquireRequest
// NOTE(review): prewarmActive / prewarmUntil are not used in the visible code — presumably prewarm scheduling state; confirm at their use sites.
prewarmActive bool
prewarmUntil time.Time
// prewarmFails / prewarmFailAt track dial failures: incremented and stamped on a failed dial, reset on a successful one.
prewarmFails int
prewarmFailAt time.Time
}
// OpenAIWSPoolMetricsSnapshot is a point-in-time copy of the pool's cumulative counters,
// as returned by SnapshotMetrics.
type OpenAIWSPoolMetricsSnapshot struct {
// Acquire* count acquisitions: total calls, leases served by reuse, and connections newly created.
AcquireTotal int64
AcquireReuseTotal int64
AcquireCreateTotal int64
// AcquireQueueWait* hold the number of queue waits and their summed duration in milliseconds.
AcquireQueueWaitTotal int64
AcquireQueueWaitMsTotal int64
// ConnPick* hold the number of connection selections and their summed duration in milliseconds.
ConnPickTotal int64
ConnPickMsTotal int64
// ScaleUpTotal / ScaleDownTotal mirror the pool's scale-up and scale-down counters.
ScaleUpTotal int64
ScaleDownTotal int64
}
// openAIWSPoolMetrics holds the pool's cumulative counters as atomics so they can be
// updated without taking a lock.
type openAIWSPoolMetrics struct {
// acquireTotal is incremented once per Acquire call.
acquireTotal atomic.Int64
// acquireReuseTotal is incremented for each lease served by an existing connection.
acquireReuseTotal atomic.Int64
// acquireCreateTotal is incremented for each connection successfully dialed by acquire.
acquireCreateTotal atomic.Int64
// acquireQueueWaitTotal counts queue waits; acquireQueueWaitMs sums the duration of successful waits in milliseconds.
acquireQueueWaitTotal atomic.Int64
acquireQueueWaitMs atomic.Int64
// connPickTotal counts connection selections; connPickMs sums their duration in milliseconds.
connPickTotal atomic.Int64
connPickMs atomic.Int64
// NOTE(review): scaleUpTotal is not updated in the visible code — confirm where it is incremented.
scaleUpTotal atomic.Int64
// scaleDownTotal is incremented when acquire evicts an idle connection to make room for a forced new one.
scaleDownTotal atomic.Int64
}
// openAIWSConnPool manages per-account pools of upstream WebSocket connections
// and the background workers that ping and clean them up.
type openAIWSConnPool struct {
cfg *config.Config
// The underlying WS client implementation is decoupled behind an interface; coder/websocket is used by default.
clientDialer openAIWSClientDialer
accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
// NOTE(review): seq is not used in the visible code — presumably a sequence source for connection IDs; confirm.
seq atomic.Uint64
metrics openAIWSPoolMetrics
// workerStopCh is closed by Close to stop the background workers; workerWg waits for them to exit.
workerStopCh chan struct{}
workerWg sync.WaitGroup
// closeOnce makes Close idempotent.
closeOnce sync.Once
}
// newOpenAIWSConnPool builds a pool using the default client dialer and starts
// its background workers. Call Close to stop those workers.
func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
	created := new(openAIWSConnPool)
	created.cfg = cfg
	created.clientDialer = newDefaultOpenAIWSClientDialer()
	created.workerStopCh = make(chan struct{})
	created.startBackgroundWorkers()
	return created
}
// SnapshotMetrics returns the current values of the pool's cumulative counters.
// Each counter is loaded individually, so the snapshot is not atomic across
// fields. A nil pool yields a zero snapshot.
func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
	var snap OpenAIWSPoolMetricsSnapshot
	if p == nil {
		return snap
	}
	counters := &p.metrics
	snap.AcquireTotal = counters.acquireTotal.Load()
	snap.AcquireReuseTotal = counters.acquireReuseTotal.Load()
	snap.AcquireCreateTotal = counters.acquireCreateTotal.Load()
	snap.AcquireQueueWaitTotal = counters.acquireQueueWaitTotal.Load()
	snap.AcquireQueueWaitMsTotal = counters.acquireQueueWaitMs.Load()
	snap.ConnPickTotal = counters.connPickTotal.Load()
	snap.ConnPickMsTotal = counters.connPickMs.Load()
	snap.ScaleUpTotal = counters.scaleUpTotal.Load()
	snap.ScaleDownTotal = counters.scaleDownTotal.Load()
	return snap
}
// SnapshotTransportMetrics returns the dialer's transport metrics when the
// configured dialer exposes them; otherwise (or for a nil pool) it returns a
// zero snapshot.
func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if p != nil {
		if reporter, supported := p.clientDialer.(openAIWSTransportMetricsDialer); supported {
			return reporter.SnapshotTransportMetrics()
		}
	}
	return OpenAIWSTransportMetricsSnapshot{}
}
// setClientDialerForTest swaps the pool's dialer. It ignores a nil pool or a
// nil dialer, leaving the current dialer in place.
func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
	if p != nil && dialer != nil {
		p.clientDialer = dialer
	}
}
// Close stops the background workers and closes every idle connection; it
// should be called during graceful shutdown. Connections that are currently
// leased are left untouched, and closed connections are not removed from their
// account pools. Only the first call has an effect; a nil pool is a no-op.
func (p *openAIWSConnPool) Close() {
	if p == nil {
		return
	}

	// closeIdle closes all non-leased connections of one account pool.
	closeIdle := func(ap *openAIWSAccountPool) {
		ap.mu.Lock()
		defer ap.mu.Unlock()
		for _, pooled := range ap.conns {
			if pooled == nil || pooled.isLeased() {
				continue
			}
			pooled.close()
		}
	}

	p.closeOnce.Do(func() {
		if p.workerStopCh != nil {
			close(p.workerStopCh)
		}
		// Wait for the workers to exit before touching the account pools.
		p.workerWg.Wait()

		// Walk every account pool and close all of its idle connections.
		p.accounts.Range(func(_, value any) bool {
			if ap, isPool := value.(*openAIWSAccountPool); isPool && ap != nil {
				closeIdle(ap)
			}
			return true
		})
	})
}
// startBackgroundWorkers launches the ping worker and the cleanup worker, each
// in its own goroutine tracked by workerWg. It does nothing for a nil pool or
// when no stop channel has been set up.
func (p *openAIWSConnPool) startBackgroundWorkers() {
	if p == nil || p.workerStopCh == nil {
		return
	}
	workers := []func(){
		p.runBackgroundPingWorker,
		p.runBackgroundCleanupWorker,
	}
	p.workerWg.Add(len(workers))
	for _, worker := range workers {
		go func(run func()) {
			defer p.workerWg.Done()
			run()
		}(worker)
	}
}
// openAIWSIdlePingCandidate pairs an idle connection with its owning account so the
// ping sweep can evict the connection if the ping fails.
type openAIWSIdlePingCandidate struct {
// accountID is the key of the account pool the connection belongs to.
accountID int64
// conn is the connection to ping.
conn *openAIWSConn
}
// runBackgroundPingWorker runs a ping sweep every
// openAIWSBackgroundPingInterval until workerStopCh is closed. It returns
// immediately for a nil pool.
func (p *openAIWSConnPool) runBackgroundPingWorker() {
	if p == nil {
		return
	}
	pingTick := time.NewTicker(openAIWSBackgroundPingInterval)
	defer pingTick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-pingTick.C:
			p.runBackgroundPingSweep()
		}
	}
}
// runBackgroundPingSweep pings every currently idle connection (at most 10 in
// parallel) and evicts those whose ping fails. Connections that became leased
// or gained waiters since the snapshot are skipped. It blocks until all pings
// have finished.
func (p *openAIWSConnPool) runBackgroundPingSweep() {
	if p == nil {
		return
	}
	idle := p.snapshotIdleConnsForPing()

	var pings errgroup.Group
	pings.SetLimit(10)
	for i := range idle {
		candidate := idle[i]
		// Re-check idleness: the snapshot may be stale by now.
		busy := candidate.conn == nil || candidate.conn.isLeased() || candidate.conn.waiters.Load() > 0
		if busy {
			continue
		}
		pings.Go(func() error {
			if pingErr := candidate.conn.pingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil {
				p.evictConn(candidate.accountID, candidate.conn.id)
			}
			return nil
		})
	}
	_ = pings.Wait()
}
// snapshotIdleConnsForPing collects, across all account pools, the connections
// that are neither leased nor have waiters. Each account pool is locked only
// while its own connections are inspected. A nil pool yields nil; otherwise the
// result is a non-nil (possibly empty) slice.
func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
	if p == nil {
		return nil
	}
	idle := make([]openAIWSIdlePingCandidate, 0)
	p.accounts.Range(func(key, value any) bool {
		accountID, isID := key.(int64)
		if !isID || accountID <= 0 {
			return true
		}
		ap, isPool := value.(*openAIWSAccountPool)
		if !isPool || ap == nil {
			return true
		}

		ap.mu.Lock()
		defer ap.mu.Unlock()
		for _, pooled := range ap.conns {
			if pooled != nil && !pooled.isLeased() && pooled.waiters.Load() <= 0 {
				idle = append(idle, openAIWSIdlePingCandidate{accountID: accountID, conn: pooled})
			}
		}
		return true
	})
	return idle
}
// runBackgroundCleanupWorker runs a cleanup sweep every
// openAIWSBackgroundSweepTicker until workerStopCh is closed. It returns
// immediately for a nil pool.
func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
	if p == nil {
		return
	}
	sweepTick := time.NewTicker(openAIWSBackgroundSweepTicker)
	defer sweepTick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-sweepTick.C:
			p.runBackgroundCleanupSweep(time.Now())
		}
	}
}
// runBackgroundCleanupSweep runs cleanup on every account pool as of now. The
// connection cap used is the account-specific one when the pool has seen an
// acquire request with an account, otherwise the pool-wide hard cap. Evicted
// connections are closed only after all account locks have been released.
func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
	if p == nil {
		return
	}

	// One batch of evicted connections per account that had any.
	batches := make([][]*openAIWSConn, 0)
	p.accounts.Range(func(_ any, value any) bool {
		ap, isPool := value.(*openAIWSAccountPool)
		if !isPool || ap == nil {
			return true
		}
		limit := p.maxConnsHardCap()

		ap.mu.Lock()
		if last := ap.lastAcquire; last != nil && last.Account != nil {
			limit = p.effectiveMaxConnsByAccount(last.Account)
		}
		dropped := p.cleanupAccountLocked(ap, now, limit)
		ap.lastCleanupAt = now
		ap.mu.Unlock()

		if len(dropped) > 0 {
			batches = append(batches, dropped)
		}
		return true
	})

	for _, batch := range batches {
		closeOpenAIWSConns(batch)
	}
}
// Acquire leases a connection for req. It counts the attempt in the pool
// metrics, works on a clone of req, and delegates to acquire with the retry
// counter at zero. A nil pool results in the error reported by acquire.
func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
	if p != nil {
		p.metrics.acquireTotal.Add(1)
	}
	cloned := cloneOpenAIWSAcquireRequest(req)
	return p.acquire(ctx, cloned, 0)
}
// acquire leases a connection for req.Account. In order, it tries to:
//  1. reuse an existing connection (skipped when req.ForceNewConn is set),
//  2. dial a new connection while the account is below its connection cap,
//  3. queue on the least-busy connection until its lease becomes free.
//
// retry counts re-entries: after a closed connection or a failed health-check
// ping the function calls itself again, but only while retry < 1.
//
// Errors: a generic error for an invalid request or empty URL;
// errOpenAIWSConnQueueFull when the cap is non-positive, a forced new
// connection cannot be created, or the target's wait queue is full;
// errOpenAIWSPreferredConnUnavailable when ForcePreferredConn is set but the
// preferred connection is missing; errOpenAIWSConnClosed when there is no
// connection to wait on; otherwise the dial, ping or context error.
func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
if p == nil || req.Account == nil || req.Account.ID <= 0 {
return nil, errors.New("invalid ws acquire request")
}
if stringsTrim(req.WSURL) == "" {
return nil, errors.New("ws url is empty")
}
accountID := req.Account.ID
effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
if effectiveMaxConns <= 0 {
return nil, errOpenAIWSConnQueueFull
}
// evicted collects connections removed while the account lock is held; they
// are closed only after the lock has been released.
var evicted []*openAIWSConn
ap := p.getOrCreateAccountPool(accountID)
ap.mu.Lock()
// Remember the latest request for this account (the background cleanup sweep
// reads its Account to compute the connection cap).
ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
now := time.Now()
// Run the per-account cleanup at most once per openAIWSAcquireCleanupInterval.
if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
ap.lastCleanupAt = now
}
pickStartedAt := time.Now()
allowReuse := !req.ForceNewConn
preferredConnID := stringsTrim(req.PreferredConnID)
// ForcePreferredConn only applies when reuse is allowed at all.
forcePreferredConn := allowReuse && req.ForcePreferredConn
if allowReuse {
if forcePreferredConn {
// Strict mode: only the preferred connection may be used.
if preferredConnID == "" {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSPreferredConnUnavailable
}
preferredConn, ok := ap.conns[preferredConnID]
if !ok || preferredConn == nil {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSPreferredConnUnavailable
}
// Fast path: the preferred connection is free right now.
if preferredConn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
// Health-check ping (when the pool deems it necessary); on failure the
// connection is closed and evicted, and the acquisition is retried once.
if p.shouldHealthCheckConn(preferredConn) {
if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
preferredConn.close()
p.evictConn(accountID, preferredConn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{
pool: p,
accountID: accountID,
conn: preferredConn,
connPick: connPick,
reused: true,
}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// The preferred connection is busy: queue on it, subject to the per-connection queue limit.
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
preferredConn.waiters.Add(1)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
// The waiter count is decremented when this call returns (including after a recursive retry).
defer preferredConn.waiters.Add(-1)
waitStart := time.Now()
p.metrics.acquireQueueWaitTotal.Add(1)
if err := preferredConn.acquire(ctx); err != nil {
if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
if p.shouldHealthCheckConn(preferredConn) {
if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
preferredConn.release()
preferredConn.close()
p.evictConn(accountID, preferredConn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
queueWait := time.Since(waitStart)
p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
lease := &openAIWSConnLease{
pool: p,
accountID: accountID,
conn: preferredConn,
queueWait: queueWait,
connPick: connPick,
reused: true,
}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// Soft preference: try the preferred connection first, but fall through to others if it is busy or missing.
if preferredConnID != "" {
if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(conn) {
if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
}
// Next, try the connection the pool considers least busy.
best := p.pickLeastBusyConnLocked(ap, "")
if best != nil && best.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(best) {
if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
best.close()
p.evictConn(accountID, best.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// Finally, take any other connection that happens to be free.
for _, conn := range ap.conns {
if conn == nil || conn == best {
continue
}
if conn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(conn) {
if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
}
}
// A forced new connection at capacity: evict the oldest idle connection to make room.
if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
delete(ap.conns, idle.id)
evicted = append(evicted, idle)
p.metrics.scaleDownTotal.Add(1)
}
}
// Below the cap (counting in-flight dials): dial a new connection outside the lock.
if len(ap.conns)+ap.creating < effectiveMaxConns {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.creating++
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
conn, dialErr := p.dialConn(ctx, req)
ap = p.getOrCreateAccountPool(accountID)
ap.mu.Lock()
ap.creating--
if dialErr != nil {
// Record the failure in the account's prewarm failure counters.
ap.prewarmFails++
ap.prewarmFailAt = time.Now()
ap.mu.Unlock()
return nil, dialErr
}
ap.conns[conn.id] = conn
ap.prewarmFails = 0
ap.prewarmFailAt = time.Time{}
ap.mu.Unlock()
p.metrics.acquireCreateTotal.Add(1)
// The new connection is already visible in the pool, so another caller may
// have leased it first; in that case wait for it.
if !conn.tryAcquire() {
if err := conn.acquire(ctx); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// At capacity and a new connection was required: give up rather than reuse.
if req.ForceNewConn {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
// At capacity with every connection busy: queue on the least-busy connection.
target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
if target == nil {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnClosed
}
if int(target.waiters.Load()) >= p.queueLimitPerConn() {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
target.waiters.Add(1)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
// The waiter count is decremented when this call returns (including after a recursive retry).
defer target.waiters.Add(-1)
waitStart := time.Now()
p.metrics.acquireQueueWaitTotal.Add(1)
if err := target.acquire(ctx); err != nil {
if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
if p.shouldHealthCheckConn(target) {
if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
target.release()
target.close()
p.evictConn(accountID, target.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
queueWait := time.Since(waitStart)
p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// recordConnPickDuration adds one connection-pick sample to the pool metrics:
// the pick counter is incremented and the duration, in whole milliseconds, is
// added to the running total. Negative durations are counted as zero and a nil
// pool is a no-op.
func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
	if p == nil {
		return
	}
	var elapsedMs int64
	if duration > 0 {
		elapsedMs = duration.Milliseconds()
	}
	p.metrics.connPickTotal.Add(1)
	p.metrics.connPickMs.Add(elapsedMs)
}
// pickOldestIdleConnLocked returns the connection with the earliest lastUsedAt
// among those that are not leased, have no waiters and are not pinned. It
// returns nil when the account pool is nil/empty or no connection qualifies.
// NOTE(review): the Locked suffix suggests the caller holds ap.mu — confirm at
// the call sites.
func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	var candidate *openAIWSConn
	for _, c := range ap.conns {
		if c == nil {
			continue
		}
		busy := c.isLeased() || c.waiters.Load() > 0
		if busy || p.isConnPinnedLocked(ap, c.id) {
			continue
		}
		if candidate != nil && !c.lastUsedAt().Before(candidate.lastUsedAt()) {
			continue
		}
		candidate = c
	}
	return candidate
}
// getOrCreateAccountPool returns the per-account pool registered under
// accountID, creating and registering an empty one when none is usable yet.
// It returns nil for a nil receiver or a non-positive accountID.
func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
	if p == nil || accountID <= 0 {
		return nil
	}
	// Fast path: an already registered, correctly typed pool.
	if raw, found := p.accounts.Load(accountID); found {
		if pool, isPool := raw.(*openAIWSAccountPool); isPool && pool != nil {
			return pool
		}
	}
	fresh := &openAIWSAccountPool{
		conns:       make(map[string]*openAIWSConn),
		pinnedConns: make(map[string]int),
	}
	// LoadOrStore hands back whichever value ended up registered, so a pool
	// stored concurrently by someone else is preferred over the fresh one.
	stored, _ := p.accounts.LoadOrStore(accountID, fresh)
	pool, isPool := stored.(*openAIWSAccountPool)
	if !isPool || pool == nil {
		return fresh
	}
	return pool
}
// ensureAccountPoolLocked is kept for compatibility with older call sites.
// It simply forwards to getOrCreateAccountPool and takes no lock itself.
func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
return p.getOrCreateAccountPool(accountID)
}
// getAccountPool looks up the per-account pool registered under accountID
// without creating one. The boolean is true only when a non-nil
// *openAIWSAccountPool was found; a nil receiver or non-positive accountID
// always yields (nil, false).
func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
	if p == nil || accountID <= 0 {
		return nil, false
	}
	raw, found := p.accounts.Load(accountID)
	if !found || raw == nil {
		return nil, false
	}
	pool, isPool := raw.(*openAIWSAccountPool)
	if !isPool || pool == nil {
		return pool, false
	}
	return pool, true
}
// isConnPinnedLocked reports whether connID currently has a positive pin
// count in ap.pinnedConns. A nil pool, empty ID or empty pin table means
// "not pinned".
func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
	switch {
	case ap == nil, connID == "":
		return false
	case len(ap.pinnedConns) == 0:
		return false
	}
	refs := ap.pinnedConns[connID]
	return refs > 0
}
// cleanupAccountLocked removes connections from ap that should no longer be
// tracked and returns them so the caller can close them (this function only
// edits the bookkeeping maps; it never calls close itself).
//
// Two passes are made over ap.conns:
//  1. nil entries are dropped; connections whose closedCh is already closed
//     are evicted; unpinned, unleased connections older than maxConnAge are
//     evicted.
//  2. when the number of tracked connections still exceeds the idle limit
//     (maxIdlePerAccount, clamped to maxConns), the least recently used
//     connections that are not leased, not waited on and not pinned are
//     evicted until the limit is met or no candidate remains.
//
// maxConns <= 0 falls back to maxConnsHardCap. The returned slice is non-nil
// for a non-nil ap, even when nothing was evicted.
func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
	if ap == nil {
		return nil
	}
	// forget drops a connection ID from both bookkeeping maps.
	forget := func(id string) {
		delete(ap.conns, id)
		if len(ap.pinnedConns) > 0 {
			delete(ap.pinnedConns, id)
		}
	}

	ageLimit := p.maxConnAge()
	evicted := make([]*openAIWSConn, 0)

	// Pass 1: dead, closed and over-age connections.
	for id, c := range ap.conns {
		if c == nil {
			forget(id)
			continue
		}
		alreadyClosed := false
		select {
		case <-c.closedCh:
			alreadyClosed = true
		default:
		}
		if alreadyClosed {
			forget(id)
			evicted = append(evicted, c)
			continue
		}
		if p.isConnPinnedLocked(ap, id) {
			continue
		}
		if ageLimit > 0 && !c.isLeased() && c.age(now) > ageLimit {
			forget(id)
			evicted = append(evicted, c)
		}
	}

	if maxConns <= 0 {
		maxConns = p.maxConnsHardCap()
	}
	idleLimit := p.maxIdlePerAccount()
	if idleLimit < 0 || idleLimit > maxConns {
		idleLimit = maxConns
	}
	if idleLimit < 0 || len(ap.conns) <= idleLimit {
		return evicted
	}

	// Pass 2: trim the surplus, oldest-used first.
	trimmable := make([]*openAIWSConn, 0, len(ap.conns))
	for id, c := range ap.conns {
		if c == nil {
			forget(id)
			continue
		}
		// A connection with waiters must not be trimmed here, otherwise the
		// pending acquire calls would observe a closed connection.
		if c.isLeased() || c.waiters.Load() > 0 || p.isConnPinnedLocked(ap, c.id) {
			continue
		}
		trimmable = append(trimmable, c)
	}
	sort.SliceStable(trimmable, func(a, b int) bool {
		return trimmable[a].lastUsedAt().Before(trimmable[b].lastUsedAt())
	})
	surplus := len(ap.conns) - idleLimit
	if surplus > len(trimmable) {
		surplus = len(trimmable)
	}
	for i := 0; i < surplus; i++ {
		victim := trimmable[i]
		forget(victim.id)
		evicted = append(evicted, victim)
	}
	if surplus > 0 {
		p.metrics.scaleDownTotal.Add(int64(surplus))
	}
	return evicted
}
// pickLeastBusyConnLocked chooses the connection a queued acquire should wait
// on.
//
// When preferredConnID (after trimming whitespace) names a usable connection
// in ap.conns, that connection is returned unconditionally. Otherwise the
// connection with the fewest waiters wins, ties being broken by the earliest
// lastUsedAt. It returns nil when ap is nil, has no connections, or holds only
// nil entries.
func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	if preferredConnID = stringsTrim(preferredConnID); preferredConnID != "" {
		// A nil map entry is treated as absent: returning it would make the
		// caller report a closed connection even though other healthy
		// connections may be available in the scan below.
		if conn := ap.conns[preferredConnID]; conn != nil {
			return conn
		}
	}
	var (
		best         *openAIWSConn
		bestWaiters  int32
		bestLastUsed time.Time
	)
	for _, conn := range ap.conns {
		if conn == nil {
			continue
		}
		waiters := conn.waiters.Load()
		lastUsed := conn.lastUsedAt()
		if best == nil ||
			waiters < bestWaiters ||
			(waiters == bestWaiters && lastUsed.Before(bestLastUsed)) {
			best = conn
			bestWaiters = waiters
			bestLastUsed = lastUsed
		}
	}
	return best
}
// accountPoolLoadLocked summarises the load on ap: inflight is the number of
// currently leased connections and waiters is the sum of every connection's
// waiter counter. Nil connections are ignored and a nil pool reports (0, 0).
func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
	if ap == nil {
		return 0, 0
	}
	for _, c := range ap.conns {
		if c != nil {
			if c.isLeased() {
				inflight++
			}
			waiters += int(c.waiters.Load())
		}
	}
	return inflight, waiters
}
// AccountPoolLoad returns a snapshot of the given account's connection pool:
// the number of leased connections, the total number of queued waiters and the
// number of tracked connections. All three are zero when the pool is nil, the
// accountID is not positive, or no pool exists for the account.
func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
	if p == nil || accountID <= 0 {
		return 0, 0, 0
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return 0, 0, 0
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	inflight, waiters = accountPoolLoadLocked(ap)
	conns = len(ap.conns)
	return inflight, waiters, conns
}
// ensureTargetIdleAsync starts a background prewarm round for accountID when
// the account has fewer connections (existing plus in-creation) than
// targetConnCountLocked asks for.
//
// Nothing happens when there is no pool or no recorded acquire request for the
// account, a prewarm round is already active, the cooldown window has not
// elapsed, or recent prewarm failures suppress further attempts. Otherwise the
// shortfall is reserved in ap.creating, the cooldown (if configured) is
// re-armed, the scale-up metric is bumped and prewarmConns is launched in its
// own goroutine with a copy of the last acquire request.
func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
	if p == nil || accountID <= 0 {
		return
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()

	if ap.lastAcquire == nil || ap.prewarmActive {
		return
	}
	now := time.Now()
	if coolingDown := !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil); coolingDown {
		return
	}
	if p.shouldSuppressPrewarmLocked(ap, now) {
		return
	}

	// Without account information only the global hard cap applies.
	connLimit := p.maxConnsHardCap()
	if acct := ap.lastAcquire.Account; acct != nil {
		connLimit = p.effectiveMaxConnsByAccount(acct)
	}
	shortfall := p.targetConnCountLocked(ap, connLimit) - (len(ap.conns) + ap.creating)
	if shortfall <= 0 {
		return
	}

	template := cloneOpenAIWSAcquireRequest(*ap.lastAcquire)
	ap.prewarmActive = true
	if cooldown := p.prewarmCooldown(); cooldown > 0 {
		ap.prewarmUntil = now.Add(cooldown)
	}
	ap.creating += shortfall
	p.metrics.scaleUpTotal.Add(int64(shortfall))
	go p.prewarmConns(accountID, template, shortfall)
}
// targetConnCountLocked computes how many connections the account should
// have, given its current load and the maxConns ceiling.
//
// With no demand (leased connections plus waiters) the result is the
// configured minimum idle count, clamped to [0, maxConns]. With demand, the
// result is ceil(demand / targetUtilization) (or 1 for a demand of exactly 1),
// raised to at least len(ap.conns)+1 while anyone is waiting, and finally
// clamped to [minIdle, maxConns]. A nil pool or maxConns <= 0 yields 0.
func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
	if ap == nil || maxConns <= 0 {
		return 0
	}
	floor := p.minIdlePerAccount()
	switch {
	case floor < 0:
		floor = 0
	case floor > maxConns:
		floor = maxConns
	}

	inflight, waiters := accountPoolLoadLocked(ap)
	utilization := p.targetUtilization()
	demand := inflight + waiters
	if demand <= 0 {
		return floor
	}

	want := 1
	if demand > 1 {
		want = int(math.Ceil(float64(demand) / utilization))
	}
	// Queued callers mean the current connections are not enough: ask for at
	// least one more than we have.
	if oneMore := len(ap.conns) + 1; waiters > 0 && want < oneMore {
		want = oneMore
	}
	if want < floor {
		want = floor
	}
	if want > maxConns {
		want = maxConns
	}
	return want
}
// prewarmConns makes up to total dial attempts for accountID using req and
// registers every successful connection in the account pool. It is started as
// a goroutine by ensureTargetIdleAsync, which has already added total to
// ap.creating and set ap.prewarmActive.
//
// Each attempt gets its own context bounded by dialTimeout plus
// openAIWSConnPrewarmExtraDelay. After every attempt ap.creating is decremented
// (never below zero). A failed dial bumps the prewarm failure counters and the
// loop moves on; a successful dial resets them. A connection dialed while the
// pool is already at effectiveMaxConnsByAccount(req.Account) is closed instead
// of being registered. If the account pool can no longer be found the
// function closes the just-dialed connection (if any) and stops.
func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
// Always clear prewarmActive on exit; ensureTargetIdleAsync refuses to start
// a new round while it is set.
defer func() {
if ap, ok := p.getAccountPool(accountID); ok && ap != nil {
ap.mu.Lock()
ap.prewarmActive = false
ap.mu.Unlock()
}
}()
for i := 0; i < total; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
// The dial runs without holding ap.mu; the lock is only taken afterwards
// to record the outcome.
conn, err := p.dialConn(ctx, req)
cancel()
ap, ok := p.getAccountPool(accountID)
if !ok || ap == nil {
if conn != nil {
conn.close()
}
return
}
ap.mu.Lock()
// Release the slot reserved for this attempt, guarding against underflow.
if ap.creating > 0 {
ap.creating--
}
if err != nil {
ap.prewarmFails++
ap.prewarmFailAt = time.Now()
ap.mu.Unlock()
continue
}
// The pool may have filled up while we were dialing; drop the surplus
// connection outside the lock.
if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) {
ap.mu.Unlock()
conn.close()
continue
}
ap.conns[conn.id] = conn
ap.prewarmFails = 0
ap.prewarmFailAt = time.Time{}
ap.mu.Unlock()
}
}
// evictConn removes connID from the account's connection and pin tables and
// closes the connection if one was registered under that ID.
//
// connID is whitespace-trimmed before the lookup, matching how PinConn,
// UnpinConn and pickLeastBusyConnLocked normalise IDs. A nil pool, a
// non-positive accountID, an empty ID or an unknown account is a no-op. The
// connection is closed after ap.mu has been released.
func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
	connID = stringsTrim(connID)
	if p == nil || accountID <= 0 || connID == "" {
		return
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return
	}
	ap.mu.Lock()
	conn, exists := ap.conns[connID]
	if exists {
		delete(ap.conns, connID)
		if len(ap.pinnedConns) > 0 {
			delete(ap.pinnedConns, connID)
		}
	}
	ap.mu.Unlock()
	if conn != nil {
		conn.close()
	}
}
// PinConn increments the pin count of connID within the account's pool and
// reports whether it did so. Pinned connections are skipped by
// cleanupAccountLocked's age and idle trimming. It returns false for a nil
// pool, a non-positive accountID, an empty (after trimming) connID, an unknown
// account, or a connection ID that is not registered.
func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
	if p == nil || accountID <= 0 {
		return false
	}
	id := stringsTrim(connID)
	if id == "" {
		return false
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return false
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	if _, known := ap.conns[id]; !known {
		return false
	}
	// The pin table may be nil when the account pool was built by hand.
	if ap.pinnedConns == nil {
		ap.pinnedConns = make(map[string]int)
	}
	ap.pinnedConns[id] += 1
	return true
}
// UnpinConn decrements the pin count of connID within the account's pool,
// removing the entry entirely once the count would drop to zero. Unknown
// accounts, empty IDs and IDs that are not pinned are silently ignored.
func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
	if p == nil || accountID <= 0 {
		return
	}
	id := stringsTrim(connID)
	if id == "" {
		return
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	if len(ap.pinnedConns) == 0 {
		return
	}
	if refs := ap.pinnedConns[id]; refs > 1 {
		ap.pinnedConns[id] = refs - 1
		return
	}
	delete(ap.pinnedConns, id)
}
// dialConn opens a new upstream WebSocket through the configured client
// dialer and wraps it in an openAIWSConn labelled with a freshly generated ID.
//
// Dial failures, including a dialer that returns a nil connection without an
// error, are reported as *openAIWSDialError carrying the handshake status code
// and a copy of the response headers. A nil pool, nil dialer or nil
// req.Account yields a plain error before any dial is attempted.
func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
	if p == nil || p.clientDialer == nil {
		return nil, errors.New("openai ws client dialer is nil")
	}
	// The account ID labels the connection below. Reject a missing account up
	// front rather than dereferencing nil after a successful dial, which would
	// panic (possibly on the prewarm goroutine) and leave the new socket open.
	if req.Account == nil {
		return nil, errors.New("openai ws acquire request account is nil")
	}
	conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
	if err != nil {
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             err,
		}
	}
	if conn == nil {
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             errors.New("openai ws dialer returned nil connection"),
		}
	}
	id := p.nextConnID(req.Account.ID)
	return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil
}
// nextConnID returns a new connection ID of the form
// "oa_ws_<accountID>_<seq>", where seq is the pool-wide counter after being
// incremented by one. The ID is assembled in a preallocated byte buffer.
func (p *openAIWSConnPool) nextConnID(accountID int64) string {
	const prefix = "oa_ws_"
	n := p.seq.Add(1)
	id := append(make([]byte, 0, 32), prefix...)
	id = strconv.AppendInt(id, accountID, 10)
	id = append(id, '_')
	return string(strconv.AppendUint(id, n, 10))
}
// shouldHealthCheckConn reports whether conn has been idle for at least
// openAIWSConnHealthCheckIdle as of now. A nil connection never needs a check.
func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
	return conn != nil && conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle
}
// maxConnsHardCap returns the configured MaxConnsPerAccount when it is
// positive, and 8 when the pool or its config is nil or the setting is unset.
func (p *openAIWSConnPool) maxConnsHardCap() int {
	const fallback = 8
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount; configured > 0 {
		return configured
	}
	return fallback
}
// dynamicMaxConnsEnabled reports the
// DynamicMaxConnsByAccountConcurrencyEnabled setting; it is false when the
// pool or its config is nil.
func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
	if p == nil || p.cfg == nil {
		return false
	}
	return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
}
// modeRouterV2Enabled reports the ModeRouterV2Enabled setting; it is false
// when the pool or its config is nil.
func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
	if p == nil || p.cfg == nil {
		return false
	}
	return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
}
// maxConnsFactorByAccount returns the multiplier applied to an account's
// concurrency when deriving its connection limit: OAuthMaxConnsFactor for
// OAuth accounts and APIKeyMaxConnsFactor for API-key accounts, when the
// configured value is positive. Every other case (nil pool/config/account,
// other account types, non-positive setting) yields 1.0.
func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
	const neutral = 1.0
	if p == nil || p.cfg == nil || account == nil {
		return neutral
	}
	var configured float64
	if account.Type == AccountTypeOAuth {
		configured = p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
	} else if account.Type == AccountTypeAPIKey {
		configured = p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
	}
	if configured > 0 {
		return configured
	}
	return neutral
}
// effectiveMaxConnsByAccount returns the connection limit that applies to
// account.
//
// With mode router v2 enabled the limit is the account's own Concurrency
// (0 when that is not positive), or the hard cap when no account is given.
// Otherwise the hard cap applies unless dynamic sizing is enabled and the
// account has a positive Concurrency, in which case the limit is
// ceil(Concurrency * factor) clamped to [1, hardCap].
func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
	hardCap := p.maxConnsHardCap()
	if hardCap <= 0 {
		return 0
	}
	if p.modeRouterV2Enabled() {
		switch {
		case account == nil:
			return hardCap
		case account.Concurrency <= 0:
			return 0
		default:
			return account.Concurrency
		}
	}
	// A Concurrency of 0/-1 ("unlimited") is still bounded by the global hard
	// cap, as is every account when dynamic sizing is off.
	if account == nil || !p.dynamicMaxConnsEnabled() || account.Concurrency <= 0 {
		return hardCap
	}
	factor := p.maxConnsFactorByAccount(account)
	if factor <= 0 {
		factor = 1.0
	}
	scaled := int(math.Ceil(float64(account.Concurrency) * factor))
	switch {
	case scaled < 1:
		return 1
	case scaled > hardCap:
		return hardCap
	}
	return scaled
}
// minIdlePerAccount returns the configured MinIdlePerAccount when it is
// non-negative, and 0 when the pool or config is nil or the value is negative.
func (p *openAIWSConnPool) minIdlePerAccount() int {
	if p == nil || p.cfg == nil {
		return 0
	}
	if configured := p.cfg.Gateway.OpenAIWS.MinIdlePerAccount; configured >= 0 {
		return configured
	}
	return 0
}
// maxIdlePerAccount returns the configured MaxIdlePerAccount when it is
// non-negative, and 4 when the pool or config is nil or the value is negative.
func (p *openAIWSConnPool) maxIdlePerAccount() int {
	const fallback = 4
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount; configured >= 0 {
		return configured
	}
	return fallback
}
// maxConnAge returns the age limit used by cleanupAccountLocked when rotating
// connections. It is the package constant openAIWSConnMaxAge and does not read
// the receiver or its config.
func (p *openAIWSConnPool) maxConnAge() time.Duration {
return openAIWSConnMaxAge
}
// queueLimitPerConn returns the configured QueueLimitPerConn when it is
// positive, and 256 when the pool or config is nil or the setting is unset.
func (p *openAIWSConnPool) queueLimitPerConn() int {
	const fallback = 256
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.QueueLimitPerConn; configured > 0 {
		return configured
	}
	return fallback
}
// targetUtilization returns the configured PoolTargetUtilization when it lies
// in (0, 1], and 0.7 otherwise (including a nil pool or config).
func (p *openAIWSConnPool) targetUtilization() float64 {
	const fallback = 0.7
	if p == nil || p.cfg == nil {
		return fallback
	}
	configured := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization
	if configured <= 0 || configured > 1 {
		return fallback
	}
	return configured
}
// prewarmCooldown converts the configured PrewarmCooldownMS into a duration.
// It is 0 (no cooldown) when the pool or config is nil or the setting is not
// positive.
func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
	if p == nil || p.cfg == nil {
		return 0
	}
	ms := p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS
	if ms <= 0 {
		return 0
	}
	return time.Duration(ms) * time.Millisecond
}
// shouldSuppressPrewarmLocked reports whether prewarming should be skipped
// because of recent dial failures recorded on ap.
//
// A nil pool is always suppressed. With no recorded failures nothing is
// suppressed. When the last failure timestamp is missing or lies further back
// than openAIWSPrewarmFailureWindow, the failure bookkeeping is reset and
// prewarming is allowed. Otherwise prewarming is suppressed once the failure
// count has reached openAIWSPrewarmFailureSuppress.
func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
	if ap == nil {
		return true
	}
	if ap.prewarmFails <= 0 {
		return false
	}
	lastFail := ap.prewarmFailAt
	if lastFail.IsZero() || now.Sub(lastFail) > openAIWSPrewarmFailureWindow {
		// Stale or inconsistent failure record: start counting afresh.
		ap.prewarmFails = 0
		ap.prewarmFailAt = time.Time{}
		return false
	}
	return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
}
// dialTimeout converts the configured DialTimeoutSeconds into a duration,
// falling back to 10 seconds when the pool or config is nil or the setting is
// not positive.
func (p *openAIWSConnPool) dialTimeout() time.Duration {
	const fallback = 10 * time.Second
	if p == nil || p.cfg == nil {
		return fallback
	}
	secs := p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds
	if secs <= 0 {
		return fallback
	}
	return time.Duration(secs) * time.Second
}
// cloneOpenAIWSAcquireRequest returns a copy of req whose Headers are deep
// copied and whose WSURL, ProxyURL and PreferredConnID are whitespace-trimmed.
// All remaining fields are copied as-is (pointer fields keep pointing at the
// same targets).
func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
	// req is already a by-value copy of the caller's struct, so it can be
	// adjusted in place and returned.
	req.Headers = cloneHeader(req.Headers)
	req.WSURL = stringsTrim(req.WSURL)
	req.ProxyURL = stringsTrim(req.ProxyURL)
	req.PreferredConnID = stringsTrim(req.PreferredConnID)
	return req
}
// cloneOpenAIWSAcquireRequestPtr is the pointer flavour of
// cloneOpenAIWSAcquireRequest: it returns a pointer to a fresh clone of *req,
// or nil when req is nil.
func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
	if req != nil {
		dup := cloneOpenAIWSAcquireRequest(*req)
		return &dup
	}
	return nil
}
// cloneHeader returns a deep copy of src: every key is preserved and each
// value slice is duplicated so later mutation of either header does not affect
// the other. Keys whose value slice is empty map to nil in the copy. A nil
// header clones to nil.
func cloneHeader(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	out := make(http.Header, len(src))
	for key, values := range src {
		var dup []string
		if n := len(values); n > 0 {
			dup = make([]string, n)
			copy(dup, values)
		}
		out[key] = dup
	}
	return out
}
// closeOpenAIWSConns calls close on every non-nil connection in conns. A nil
// or empty slice is a no-op.
func closeOpenAIWSConns(conns []*openAIWSConn) {
	for i := range conns {
		if c := conns[i]; c != nil {
			c.close()
		}
	}
}
// stringsTrim returns value with leading and trailing white space removed; it
// is a thin wrapper around strings.TrimSpace.
func stringsTrim(value string) string {
return strings.TrimSpace(value)
}
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// BenchmarkOpenAIWSPoolAcquire measures parallel Acquire/Release round trips
// against a pool backed by a counting test dialer. One lease is acquired and
// released before the timer is reset so the measured loop starts with a
// connection already in the pool. Each iteration retries Acquire up to three
// times, but only while the error is errOpenAIWSConnClosed.
func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(&openAIWSCountingDialer{})
	account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	req := openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ctx := context.Background()
	lease, err := pool.Acquire(ctx, req)
	if err != nil {
		b.Fatalf("warm acquire failed: %v", err)
	}
	lease.Release()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var (
				got        *openAIWSConnLease
				acquireErr error
			)
			for retry := 0; retry < 3; retry++ {
				got, acquireErr = pool.Acquire(ctx, req)
				if acquireErr == nil || !errors.Is(acquireErr, errOpenAIWSConnClosed) {
					break
				}
			}
			if acquireErr != nil {
				// RunParallel executes this body on worker goroutines, where
				// Fatal/FailNow must not be called. Record the failure and
				// stop this worker instead.
				b.Errorf("acquire failed: %v", acquireErr)
				return
			}
			got.Release()
		}
	})
}
package service
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSConnPool_CleanupStaleAndTrimIdle registers three idle
// connections (one created and last used two hours ago, two used ten minutes
// and one minute ago) with MaxIdlePerAccount = 1 and checks that a single
// cleanup pass removes the two-hour-old one as well as the older of the two
// remaining idle connections, keeping only the most recently used one.
func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	pool := newOpenAIWSConnPool(cfg)
	accountID := int64(10)
	ap := pool.getOrCreateAccountPool(accountID)

	nanosAgo := func(d time.Duration) int64 {
		return time.Now().Add(-d).UnixNano()
	}

	stale := newOpenAIWSConn("stale", accountID, nil, nil)
	stale.createdAtNano.Store(nanosAgo(2 * time.Hour))
	stale.lastUsedNano.Store(nanosAgo(2 * time.Hour))

	idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil)
	idleOld.lastUsedNano.Store(nanosAgo(10 * time.Minute))

	idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil)
	idleNew.lastUsedNano.Store(nanosAgo(1 * time.Minute))

	for _, c := range []*openAIWSConn{stale, idleOld, idleNew} {
		ap.conns[c.id] = c
	}

	closeOpenAIWSConns(pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()))

	require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
	require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
	require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
}
// TestOpenAIWSConnPool_NextConnIDFormat checks that consecutive IDs from a
// fresh pool share the "oa_ws_<accountID>_" prefix, differ from each other and
// end in the sequence numbers 1 and 2.
func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
	pool := newOpenAIWSConnPool(&config.Config{})
	first := pool.nextConnID(42)
	second := pool.nextConnID(42)
	for _, id := range []string{first, second} {
		require.True(t, strings.HasPrefix(id, "oa_ws_42_"))
	}
	require.NotEqual(t, first, second)
	require.Equal(t, "oa_ws_42_1", first)
	require.Equal(t, "oa_ws_42_2", second)
}
// TestOpenAIWSConnPool_AcquireCleanupInterval pins the acquire-time cleanup
// interval constant at three seconds and requires it to be shorter than the
// background sweep ticker constant.
func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
}
// TestOpenAIWSConnLease_WriteJSONAndGuards checks that WriteJSON succeeds on a
// lease backed by a fake connection, and that WriteJSONWithContextTimeout
// reports errOpenAIWSConnClosed for both a nil lease and a lease without a
// connection.
func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
	newPayload := func() map[string]any {
		return map[string]any{"type": "response.create"}
	}

	conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
	working := &openAIWSConnLease{conn: conn}
	require.NoError(t, working.WriteJSON(newPayload(), 0))

	// Guard cases: a nil lease first, then a lease with no connection.
	for _, lease := range []*openAIWSConnLease{nil, {}} {
		err := lease.WriteJSONWithContextTimeout(context.Background(), newPayload(), time.Second)
		require.ErrorIs(t, err, errOpenAIWSConnClosed)
	}
}
// TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground writes
// through a probe connection with a zero timeout and checks that the write
// succeeds and that the probe recorded a non-nil context.
// NOTE(review): the name refers to a nil parent context, but the call below
// passes context.Background() — confirm whether a nil parent is still meant
// to be exercised here.
func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
probe := &openAIWSContextProbeConn{}
conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)
require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0))
require.NotNil(t, probe.lastWriteCtx)
}
// TestOpenAIWSConnPool_TargetConnCountAdaptive checks both ends of the
// adaptive target. With two leased connections carrying one waiter each
// (demand 4) and a target utilization of 0.5 the raw target of 8 is capped at
// MaxConnsPerAccount = 6; once the load is gone the target falls back to
// MinIdlePerAccount = 1.
func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5
	pool := newOpenAIWSConnPool(cfg)
	ap := pool.getOrCreateAccountPool(88)

	loaded := []*openAIWSConn{
		newOpenAIWSConn("c1", 88, nil, nil),
		newOpenAIWSConn("c2", 88, nil, nil),
	}
	for _, c := range loaded {
		require.True(t, c.tryAcquire())
	}
	for _, c := range loaded {
		c.waiters.Store(1)
		ap.conns[c.id] = c
	}

	busyTarget := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
	require.Equal(t, 6, busyTarget, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")

	for _, c := range loaded {
		c.release()
	}
	for _, c := range loaded {
		c.waiters.Store(0)
	}

	quietTarget := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
	require.Equal(t, 1, quietTarget, "低负载时应缩回到最小空闲连接")
}
// TestOpenAIWSConnPool_TargetConnCountMinIdleZero checks that an empty account
// pool with MinIdlePerAccount = 0 yields a target of zero connections.
func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
	pool := newOpenAIWSConnPool(cfg)

	emptyPool := pool.getOrCreateAccountPool(66)
	got := pool.targetConnCountLocked(emptyPool, pool.maxConnsHardCap())

	require.Equal(t, 0, got, "min_idle=0 且无负载时应允许缩容到 0")
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsync seeds an account pool with a last
// acquire request, triggers ensureTargetIdleAsync once, and waits for the
// background prewarm to bring the pool up to MinIdlePerAccount (2)
// connections. It also checks that the scale-up metric counted at least two.
func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(&openAIWSFakeDialer{})
accountID := int64(77)
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
ap := pool.getOrCreateAccountPool(accountID)
// ensureTargetIdleAsync does nothing without a recorded acquire request, so
// one is planted by hand.
ap.mu.Lock()
ap.lastAcquire = &openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
}
ap.mu.Unlock()
pool.ensureTargetIdleAsync(accountID)
// Prewarming runs on its own goroutine; poll until the connections appear.
require.Eventually(t, func() bool {
ap, ok := pool.getAccountPool(accountID)
if !ok || ap == nil {
return false
}
ap.mu.Lock()
defer ap.mu.Unlock()
return len(ap.conns) >= 2
}, 2*time.Second, 20*time.Millisecond)
metrics := pool.SnapshotMetrics()
require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown verifies the prewarm
// cooldown (500 ms here). After a first prewarm round has filled the pool, one
// connection is removed by hand; a trigger inside the cooldown window must not
// dial again, while a trigger after the window has elapsed must.
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
pool := newOpenAIWSConnPool(cfg)
dialer := &openAIWSCountingDialer{}
pool.setClientDialerForTest(dialer)
accountID := int64(178)
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
ap := pool.getOrCreateAccountPool(accountID)
ap.mu.Lock()
ap.lastAcquire = &openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
}
ap.mu.Unlock()
pool.ensureTargetIdleAsync(accountID)
// Wait for the first round to finish completely (prewarmActive cleared) so
// the next trigger is blocked only by the cooldown.
require.Eventually(t, func() bool {
ap, ok := pool.getAccountPool(accountID)
if !ok || ap == nil {
return false
}
ap.mu.Lock()
defer ap.mu.Unlock()
return len(ap.conns) >= 2 && !ap.prewarmActive
}, 2*time.Second, 20*time.Millisecond)
firstDialCount := dialer.DialCount()
require.GreaterOrEqual(t, firstDialCount, 2)
// Artificially open a gap so that a new prewarm round would be needed.
ap, ok := pool.getAccountPool(accountID)
require.True(t, ok)
require.NotNil(t, ap)
ap.mu.Lock()
for id := range ap.conns {
delete(ap.conns, id)
break
}
ap.mu.Unlock()
pool.ensureTargetIdleAsync(accountID)
time.Sleep(120 * time.Millisecond)
require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
// 120 ms + 450 ms exceeds the 500 ms cooldown, so the next trigger may dial.
time.Sleep(450 * time.Millisecond)
pool.ensureTargetIdleAsync(accountID)
require.Eventually(t, func() bool {
return dialer.DialCount() > firstDialCount
}, 2*time.Second, 20*time.Millisecond)
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress verifies that
// repeated prewarm failures stop further dialing. With an always-failing
// dialer, no cooldown and MinIdlePerAccount = 1, two prewarm rounds are run to
// completion (two dials in total); a third trigger must not dial again.
// NOTE(review): this implies the suppression threshold
// (openAIWSPrewarmFailureSuppress) is reached after two failures — confirm
// against the constant's definition.
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0
pool := newOpenAIWSConnPool(cfg)
dialer := &openAIWSAlwaysFailDialer{}
pool.setClientDialerForTest(dialer)
accountID := int64(279)
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
ap := pool.getOrCreateAccountPool(accountID)
ap.mu.Lock()
ap.lastAcquire = &openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
}
ap.mu.Unlock()
// Round 1: trigger and wait until the prewarm goroutine has finished.
pool.ensureTargetIdleAsync(accountID)
require.Eventually(t, func() bool {
ap, ok := pool.getAccountPool(accountID)
if !ok || ap == nil {
return false
}
ap.mu.Lock()
defer ap.mu.Unlock()
return !ap.prewarmActive
}, 2*time.Second, 20*time.Millisecond)
// Round 2: same again.
pool.ensureTargetIdleAsync(accountID)
require.Eventually(t, func() bool {
ap, ok := pool.getAccountPool(accountID)
if !ok || ap == nil {
return false
}
ap.mu.Lock()
defer ap.mu.Unlock()
return !ap.prewarmActive
}, 2*time.Second, 20*time.Millisecond)
require.Equal(t, 2, dialer.DialCount())
// Once consecutive failures reach the threshold, further prewarm triggers
// must be suppressed and no more dials issued.
pool.ensureTargetIdleAsync(accountID)
time.Sleep(120 * time.Millisecond)
require.Equal(t, 2, dialer.DialCount())
}
// TestOpenAIWSConnPool_AcquireQueueWaitMetrics forces an Acquire to queue
// behind the only connection of an account (MaxConnsPerAccount = 1), which is
// released by a helper goroutine after 60 ms. It then checks that the lease is
// marked as reused, that it reports a queue wait of at least 50 ms, and that
// the queue-wait and conn-pick metrics were recorded.
func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
pool := newOpenAIWSConnPool(cfg)
accountID := int64(99)
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
require.True(t, conn.tryAcquire()) // occupy the connection so the acquire below has to queue
ap := pool.ensureAccountPoolLocked(accountID)
ap.mu.Lock()
ap.conns[conn.id] = conn
ap.lastAcquire = &openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
}
ap.mu.Unlock()
// Free the connection shortly after the Acquire below has started waiting.
go func() {
time.Sleep(60 * time.Millisecond)
conn.release()
}()
lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
})
require.NoError(t, err)
require.NotNil(t, lease)
require.True(t, lease.Reused())
require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
lease.Release()
metrics := pool.SnapshotMetrics()
require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1))
require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0))
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
}
// TestOpenAIWSConnPool_ForceNewConnSkipsReuse acquires and releases a lease
// twice for the same account, the second time with ForceNewConn set, and
// checks that the dialer was invoked twice, i.e. the idle connection left by
// the first lease was not reused.
func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)
	dialer := &openAIWSCountingDialer{}
	pool.setClientDialerForTest(dialer)
	account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// First a plain acquire, then one that demands a brand-new connection.
	for _, forceNew := range []bool{false, true} {
		lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
			Account:      account,
			WSURL:        "wss://example.com/v1/responses",
			ForceNewConn: forceNew,
		})
		require.NoError(t, err)
		require.NotNil(t, lease)
		lease.Release()
	}

	require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable checks that an
// Acquire with ForcePreferredConn set fails with
// errOpenAIWSPreferredConnUnavailable when the preferred ID is empty or names
// a connection that is not in the pool, even though another connection is
// registered for the account.
func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)
	account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ap := pool.getOrCreateAccountPool(account.ID)
	otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[otherConn.id] = otherConn
	ap.mu.Unlock()

	for _, preferredID := range []string{"", "missing_conn"} {
		_, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
			Account:            account,
			WSURL:              "wss://example.com/v1/responses",
			PreferredConnID:    preferredID,
			ForcePreferredConn: true,
		})
		require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
	}
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly checks
// strict preferred-connection mode while the preferred connection is busy:
// the Acquire must wait for that connection (released after 60 ms by a helper
// goroutine) rather than take the idle sibling. It asserts the lease's
// connection ID, a queue wait of at least 40 ms, and that the sibling is still
// free afterwards.
func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
pool := newOpenAIWSConnPool(cfg)
account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
ap := pool.getOrCreateAccountPool(account.ID)
preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取")
ap.mu.Lock()
ap.conns[preferredConn.id] = preferredConn
ap.conns[otherConn.id] = otherConn
// NOTE(review): presumably marks cleanup as just done so Acquire does not
// run a cleanup pass over these hand-built connections — confirm in acquire.
ap.lastCleanupAt = time.Now()
ap.mu.Unlock()
// Free the preferred connection while the Acquire below is waiting on it.
go func() {
time.Sleep(60 * time.Millisecond)
preferredConn.release()
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
PreferredConnID: preferredConn.id,
ForcePreferredConn: true,
})
require.NoError(t, err)
require.NotNil(t, lease)
require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
lease.Release()
// The sibling must still be acquirable, i.e. it was never leased.
require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占")
otherConn.release()
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull covers two
// strict preferred-connection cases with QueueLimitPerConn = 1: an idle
// preferred connection is leased directly, and a preferred connection that is
// leased with its queue already full makes Acquire fail with
// errOpenAIWSConnQueueFull instead of switching to the other connection.
func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
	pool := newOpenAIWSConnPool(cfg)
	account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ap := pool.getOrCreateAccountPool(account.ID)
	preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
	otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[preferredConn.id] = preferredConn
	ap.conns[otherConn.id] = otherConn
	ap.lastCleanupAt = time.Now()
	ap.mu.Unlock()

	strictReq := openAIWSAcquireRequest{
		Account:            account,
		WSURL:              "wss://example.com/v1/responses",
		PreferredConnID:    preferredConn.id,
		ForcePreferredConn: true,
	}

	// Case 1: the preferred connection is idle.
	lease, err := pool.Acquire(context.Background(), strictReq)
	require.NoError(t, err)
	require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中")
	lease.Release()

	// Case 2: the preferred connection is leased and its queue is at the limit.
	require.True(t, preferredConn.tryAcquire())
	preferredConn.waiters.Store(1)
	_, err = pool.Acquire(context.Background(), strictReq)
	require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")

	preferredConn.waiters.Store(0)
	preferredConn.release()
}
// TestOpenAIWSConnPool_CleanupSkipsPinnedConn checks that, with
// MaxIdlePerAccount = 0, a cleanup pass evicts an unpinned idle connection but
// keeps a pinned one, and that the formerly pinned connection is evicted by
// the next cleanup pass once it has been unpinned.
func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
	pool := newOpenAIWSConnPool(cfg)
	accountID := int64(126)
	ap := pool.getOrCreateAccountPool(accountID)
	pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
	idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[pinnedConn.id] = pinnedConn
	ap.conns[idleConn.id] = idleConn
	ap.mu.Unlock()

	// tracked reports whether the pool still holds an entry for id.
	tracked := func(id string) bool {
		ap.mu.Lock()
		defer ap.mu.Unlock()
		_, present := ap.conns[id]
		return present
	}
	runCleanup := func() {
		closeOpenAIWSConns(pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()))
	}

	require.True(t, pool.PinConn(accountID, pinnedConn.id))
	runCleanup()
	pinnedExists := tracked(pinnedConn.id)
	idleExists := tracked(idleConn.id)
	require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收")
	require.False(t, idleExists, "非绑定的空闲连接应被回收")

	pool.UnpinConn(accountID, pinnedConn.id)
	runCleanup()
	require.False(t, tracked(pinnedConn.id), "解绑后连接应可被正常回收")
}
// TestOpenAIWSConnPool_PinUnpinConnBranches walks through the guard and
// ref-counting branches of PinConn and UnpinConn: nil pool, invalid or unknown
// account, empty or unknown connection ID, a pin table that starts out nil,
// a pin count going 2 -> 1 -> removed, and unpin calls that have nothing left
// to do.
func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
	var nilPool *openAIWSConnPool
	require.False(t, nilPool.PinConn(1, "x"))
	nilPool.UnpinConn(1, "x")

	pool := newOpenAIWSConnPool(&config.Config{})
	accountID := int64(128)
	// Built by hand without a pin table so PinConn has to create it.
	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{},
	}
	pool.accounts.Store(accountID, ap)

	rejected := []struct {
		account int64
		connID  string
	}{
		{0, "x"},
		{999, "x"},
		{accountID, ""},
		{accountID, "missing"},
	}
	for _, tc := range rejected {
		require.False(t, pool.PinConn(tc.account, tc.connID))
	}

	conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[conn.id] = conn
	ap.mu.Unlock()

	// pinState reads the pin count for conn under the pool lock.
	pinState := func() (int, bool) {
		ap.mu.Lock()
		defer ap.mu.Unlock()
		refs, present := ap.pinnedConns[conn.id]
		return refs, present
	}

	require.True(t, pool.PinConn(accountID, conn.id))
	require.True(t, pool.PinConn(accountID, conn.id))
	refs, _ := pinState()
	require.Equal(t, 2, refs)

	pool.UnpinConn(accountID, conn.id)
	refs, _ = pinState()
	require.Equal(t, 1, refs)

	pool.UnpinConn(accountID, conn.id)
	_, exists := pinState()
	require.False(t, exists)

	// None of these has anything to unpin; they must simply not panic.
	pool.UnpinConn(accountID, conn.id)
	pool.UnpinConn(accountID, "")
	pool.UnpinConn(0, conn.id)
	pool.UnpinConn(999, conn.id)
}
// TestOpenAIWSConnPool_EffectiveMaxConnsByAccount checks the per-account
// connection cap when dynamic sizing is enabled with a hard cap of 8, an
// OAuth factor of 1.0 and an API-key factor of 0.6.
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.MaxConnsPerAccount = 8
	ws.DynamicMaxConnsByAccountConcurrencyEnabled = true
	ws.OAuthMaxConnsFactor = 1.0
	ws.APIKeyMaxConnsFactor = 0.6
	pool := newOpenAIWSConnPool(cfg)

	// Each case is one account and the cap expected for it.
	cases := []struct {
		account *Account
		want    int
		msg     string
	}{
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}, 8, "应受全局硬上限约束"},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}, 3, ""},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}, 6, "API Key 应按系数缩放"},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}, 1, "最小值应保持为 1"},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}, 8, "无限并发应回退到全局硬上限"},
		{nil, 8, "缺少账号上下文应回退到全局硬上限"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, pool.effectiveMaxConnsByAccount(tc.account), tc.msg)
	}
}
// TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap checks that
// with dynamic sizing switched off the cap stays at MaxConnsPerAccount even
// for an account whose concurrency is lower.
func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.MaxConnsPerAccount = 8
	ws.DynamicMaxConnsByAccountConcurrencyEnabled = false
	ws.OAuthMaxConnsFactor = 1.0
	ws.APIKeyMaxConnsFactor = 1.0
	pool := newOpenAIWSConnPool(cfg)

	lowConcurrency := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
	got := pool.effectiveMaxConnsByAccount(lowConcurrency)
	require.Equal(t, 8, got, "关闭动态模式后应保持旧行为")
}
// TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency
// checks that with ModeRouterV2Enabled the cap equals the account's own
// concurrency (20 here, above the configured 8) and is 0 for an account whose
// concurrency is 0.
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.ModeRouterV2Enabled = true
	ws.MaxConnsPerAccount = 8
	ws.DynamicMaxConnsByAccountConcurrencyEnabled = true
	ws.OAuthMaxConnsFactor = 0.3
	ws.APIKeyMaxConnsFactor = 0.6
	pool := newOpenAIWSConnPool(cfg)

	busyOAuth := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
	gotHigh := pool.effectiveMaxConnsByAccount(busyOAuth)
	require.Equal(t, 20, gotHigh, "v2 路径应直接使用账号并发数作为池上限")

	zeroAPIKey := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
	gotZero := pool.effectiveMaxConnsByAccount(zeroAPIKey)
	require.Equal(t, 0, gotZero, "并发数<=0 时应不可调度")
}
// TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero checks that
// Acquire fails with errOpenAIWSConnQueueFull for a zero-concurrency account
// when ModeRouterV2 is enabled.
func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.ModeRouterV2Enabled = true
	ws.MaxConnsPerAccount = 8
	pool := newOpenAIWSConnPool(cfg)

	req := openAIWSAcquireRequest{
		Account: &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0},
		WSURL:   "wss://example.com/v1/responses",
	}
	_, acquireErr := pool.Acquire(context.Background(), req)
	require.ErrorIs(t, acquireErr, errOpenAIWSConnQueueFull)
}
func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
lease := &openAIWSConnLease{conn: conn}
_, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
require.NoError(t, err)
require.Contains(t, string(payload), "response.completed")
parentCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}
func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
parentCtx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
start := time.Now()
err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
elapsed := time.Since(start)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
require.Less(t, elapsed, 200*time.Millisecond)
}
func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
var nilLease *openAIWSConnLease
err := nilLease.PingWithTimeout(50 * time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
}
// TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently checks that a ping can
// finish while a slow read is still in progress on the same connection.
func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
	conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)

	readErrCh := make(chan error, 1)
	go func() {
		_, readErr := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
		readErrCh <- readErr
	}()

	// Let the read take readMu first.
	time.Sleep(20 * time.Millisecond)

	pingStart := time.Now()
	pingErr := conn.pingWithTimeout(50 * time.Millisecond)
	pingTook := time.Since(pingStart)

	require.NoError(t, pingErr)
	require.Less(t, pingTook, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
	require.NoError(t, <-readErrCh)
}
// TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn checks that an
// idle connection whose Ping fails is removed from the account pool by a
// background ping sweep.
func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(301)
	ap := pool.getOrCreateAccountPool(accountID)
	dead := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)

	ap.mu.Lock()
	ap.conns[dead.id] = dead
	ap.mu.Unlock()

	pool.runBackgroundPingSweep()

	ap.mu.Lock()
	_, stillPooled := ap.conns[dead.id]
	ap.mu.Unlock()
	require.False(t, stillPooled, "后台 ping 失败的空闲连接应被回收")
}
// TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire checks that a
// connection created and last used two hours ago is removed by the background
// cleanup sweep without any Acquire call.
func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.MaxConnsPerAccount = 2
	ws.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(302)
	ap := pool.getOrCreateAccountPool(accountID)

	// Backdate both timestamps so the connection counts as expired.
	expired := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
	expired.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	expired.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())

	ap.mu.Lock()
	ap.conns[expired.id] = expired
	ap.mu.Unlock()

	pool.runBackgroundCleanupSweep(time.Now())

	ap.mu.Lock()
	_, stillPooled := ap.conns[expired.id]
	ap.mu.Unlock()
	require.False(t, stillPooled, "后台清理应在无新 acquire 时也回收过期连接")
}
// TestOpenAIWSConnPool_BackgroundWorkerGuardBranches checks that the
// background-worker entry points tolerate a nil pool and a pool without a stop
// channel, and that both worker loops return once workerStopCh is closed.
func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
	// None of the worker entry points may panic on a nil pool.
	var nilPool *openAIWSConnPool
	require.NotPanics(t, func() {
		nilPool.startBackgroundWorkers()
		nilPool.runBackgroundPingWorker()
		nilPool.runBackgroundPingSweep()
		_ = nilPool.snapshotIdleConnsForPing()
		nilPool.runBackgroundCleanupWorker()
		nilPool.runBackgroundCleanupSweep(time.Now())
	})

	// Starting workers on a zero-value pool (no stop channel) must not panic.
	bare := &openAIWSConnPool{}
	require.NotPanics(t, func() {
		bare.startBackgroundWorkers()
	})

	// expectStops runs worker on a fresh pool, closes its stop channel and
	// fails the test with failMsg if the worker has not returned within 500ms.
	expectStops := func(worker func(p *openAIWSConnPool), failMsg string) {
		p := &openAIWSConnPool{workerStopCh: make(chan struct{})}
		exited := make(chan struct{})
		go func() {
			worker(p)
			close(exited)
		}()
		close(p.workerStopCh)
		select {
		case <-exited:
		case <-time.After(500 * time.Millisecond):
			t.Fatal(failMsg)
		}
	}

	expectStops(
		func(p *openAIWSConnPool) { p.runBackgroundPingWorker() },
		"runBackgroundPingWorker 未在 stop 信号后退出",
	)
	expectStops(
		func(p *openAIWSConnPool) { p.runBackgroundCleanupWorker() },
		"runBackgroundCleanupWorker 未在 stop 信号后退出",
	)
}
// TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries checks
// that the ping snapshot returns only the idle connection and ignores
// mistyped map entries, nil connections, leased connections and connections
// that have waiters.
func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
	pool := &openAIWSConnPool{}
	// Entries with a wrong key type or a wrong value type.
	pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
	pool.accounts.Store(int64(123), "invalid-value")

	accountID := int64(123)

	leasedConn := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, leasedConn.tryAcquire())

	queuedConn := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
	queuedConn.waiters.Store(1)

	idleConn := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)

	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{
			"nil_conn":    nil,
			leasedConn.id: leasedConn,
			queuedConn.id: queuedConn,
			idleConn.id:   idleConn,
		},
	}
	// This replaces the mistyped value stored under the same key above.
	pool.accounts.Store(accountID, ap)

	picked := pool.snapshotIdleConnsForPing()
	require.Len(t, picked, 1)
	require.Equal(t, idleConn.id, picked[0].conn.id)
}
// TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap
// checks that a cleanup sweep ignores mistyped account entries, removes nil
// and expired connections, and stamps the account pool with the sweep time.
func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.MaxConnsPerAccount = 4
	ws.DynamicMaxConnsByAccountConcurrencyEnabled = true
	pool := &openAIWSConnPool{cfg: cfg}
	pool.accounts.Store("bad-key", "bad-value")

	accountID := int64(2026)

	// Backdate both timestamps by two hours.
	expired := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
	expired.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	expired.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())

	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{
			"nil_conn": nil,
			expired.id: expired,
		},
		lastAcquire: &openAIWSAcquireRequest{
			Account: &Account{
				ID:          accountID,
				Platform:    PlatformOpenAI,
				Type:        AccountTypeAPIKey,
				Concurrency: 1,
			},
		},
	}
	pool.accounts.Store(accountID, ap)

	sweepAt := time.Now()
	require.NotPanics(t, func() {
		pool.runBackgroundCleanupSweep(sweepAt)
	})

	ap.mu.Lock()
	_, nilEntryLeft := ap.conns["nil_conn"]
	_, expiredLeft := ap.conns[expired.id]
	stampedAt := ap.lastCleanupAt
	ap.mu.Unlock()

	require.False(t, nilEntryLeft, "后台清理应移除无效 nil 连接条目")
	require.False(t, expiredLeft, "后台清理应清理过期连接")
	require.Equal(t, sweepAt, stampedAt)
}
// TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured checks that
// queueLimitPerConn returns 256 for a nil pool and for an unset config, and
// returns the configured value once one is set.
func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
	var missing *openAIWSConnPool
	require.Equal(t, 256, missing.queueLimitPerConn())

	configured := &openAIWSConnPool{cfg: &config.Config{}}
	require.Equal(t, 256, configured.queueLimitPerConn())

	configured.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
	require.Equal(t, 9, configured.queueLimitPerConn())
}
func TestOpenAIWSConnPool_Close(t *testing.T) {
cfg := &config.Config{}
pool := newOpenAIWSConnPool(cfg)
// Close 应该可以安全调用
pool.Close()
// workerStopCh 应已关闭
select {
case <-pool.workerStopCh:
// 预期:channel 已关闭
default:
t.Fatal("Close 后 workerStopCh 应已关闭")
}
// 多次调用 Close 不应 panic
pool.Close()
// nil pool 调用 Close 不应 panic
var nilPool *openAIWSConnPool
nilPool.Close()
}
// TestOpenAIWSDialError_ErrorAndUnwrap checks the Error text and Unwrap
// result of openAIWSDialError with a status code, without one, and on a nil
// receiver.
func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
	cause := errors.New("boom")

	withStatus := &openAIWSDialError{StatusCode: 502, Err: cause}
	require.Contains(t, withStatus.Error(), "status=502")
	require.ErrorIs(t, withStatus.Unwrap(), cause)

	withoutStatus := &openAIWSDialError{Err: cause}
	require.Contains(t, withoutStatus.Error(), "boom")

	// A nil receiver gives an empty message and no wrapped error.
	var missing *openAIWSDialError
	require.Equal(t, "", missing.Error())
	require.NoError(t, missing.Unwrap())
}
// TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats exercises the lease and
// connection read/write helpers against a fake connection, checks the
// handshake-header and timestamp accessors, and checks that the helpers
// return errOpenAIWSConnClosed on a nil connection.
func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
	handshake := http.Header{
		"X-Test": []string{" value "},
	}
	conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, handshake)
	lease := &openAIWSConnLease{conn: conn}
	bg := context.Background()

	require.NoError(t, lease.WriteJSONContext(bg, map[string]any{"type": "response.create"}))

	// Each read helper returns the fake connection's canned message.
	viaLease, err := lease.ReadMessage(100 * time.Millisecond)
	require.NoError(t, err)
	require.Contains(t, string(viaLease), "response.completed")

	viaLeaseCtx, err := lease.ReadMessageContext(bg)
	require.NoError(t, err)
	require.Contains(t, string(viaLeaseCtx), "response.completed")

	viaConn, err := conn.readMessageWithTimeout(100 * time.Millisecond)
	require.NoError(t, err)
	require.Contains(t, string(viaConn), "response.completed")

	// The header lookup trims both the requested name and the stored value.
	require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
	require.NotZero(t, conn.createdAt())
	require.NotZero(t, conn.lastUsedAt())
	require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
	require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
	require.False(t, conn.isLeased())

	// Cover the empty-context path.
	_, err = conn.readMessage(bg)
	require.NoError(t, err)

	// Cover the nil-receiver guard branches.
	var missing *openAIWSConn
	require.ErrorIs(t, missing.writeJSONWithTimeout(bg, map[string]any{}, time.Second), errOpenAIWSConnClosed)
	_, err = missing.readMessageWithTimeout(10 * time.Millisecond)
	require.ErrorIs(t, err, errOpenAIWSConnClosed)
	_, err = missing.readMessageWithContextTimeout(bg, 10*time.Millisecond)
	require.ErrorIs(t, err, errOpenAIWSConnClosed)
}
// TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad checks that
// pickOldestIdleConnLocked returns the idle connection used longest ago, and
// that the load helpers report one leased connection, two waiters and three
// connections in total (all zero for account 0).
func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
	pool := &openAIWSConnPool{}
	accountID := int64(404)

	// minutesAgo returns the UnixNano timestamp of m minutes before now.
	minutesAgo := func(m int) int64 {
		return time.Now().Add(-time.Duration(m) * time.Minute).UnixNano()
	}

	olderIdle := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
	olderIdle.lastUsedNano.Store(minutesAgo(10))
	newerIdle := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
	newerIdle.lastUsedNano.Store(minutesAgo(1))
	inUse := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, inUse.tryAcquire())
	inUse.waiters.Store(2)

	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{
			olderIdle.id: olderIdle,
			newerIdle.id: newerIdle,
			inUse.id:     inUse,
		},
	}

	picked := pool.pickOldestIdleConnLocked(ap)
	require.NotNil(t, picked)
	require.Equal(t, olderIdle.id, picked.id)

	lockedInflight, lockedWaiters := accountPoolLoadLocked(ap)
	require.Equal(t, 1, lockedInflight)
	require.Equal(t, 2, lockedWaiters)

	pool.accounts.Store(accountID, ap)
	gotInflight, gotWaiters, gotConns := pool.AccountPoolLoad(accountID)
	require.Equal(t, 1, gotInflight)
	require.Equal(t, 2, gotWaiters)
	require.Equal(t, 3, gotConns)

	// Account 0 reports an empty load.
	noInflight, noWaiters, noConns := pool.AccountPoolLoad(0)
	require.Equal(t, 0, noInflight)
	require.Equal(t, 0, noWaiters)
	require.Equal(t, 0, noConns)
}
// TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel checks that
// Close on a pool with a nil stop channel still waits for workerWg: it must
// not return while a registered worker is running and must return once that
// worker finishes.
func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
	pool := &openAIWSConnPool{}
	release := make(chan struct{})
	// Release the fake worker exactly once, including when an assertion
	// aborts the test early; otherwise the worker goroutine and the goroutine
	// blocked in Close would outlive the test.
	var releaseOnce sync.Once
	releaseWorker := func() { releaseOnce.Do(func() { close(release) }) }
	t.Cleanup(releaseWorker)

	// Register a fake worker that only finishes when released.
	pool.workerWg.Add(1)
	go func() {
		defer pool.workerWg.Done()
		<-release
	}()

	closed := make(chan struct{})
	go func() {
		pool.Close()
		close(closed)
	}()

	// While the worker is still running, Close must stay blocked.
	select {
	case <-closed:
		t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
	case <-time.After(30 * time.Millisecond):
	}

	// Once the worker is released, Close must return promptly.
	releaseWorker()
	select {
	case <-closed:
	case <-time.After(time.Second):
		t.Fatal("Close 未等待 workerWg 完成")
	}
}
// TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections checks that Close
// closes an idle connection, leaves a leased one open, and tolerates a
// mistyped entry in the accounts map.
func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
	pool := &openAIWSConnPool{workerStopCh: make(chan struct{})}
	accountID := int64(606)

	idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
	leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, leased.tryAcquire())

	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{
			idle.id:   idle,
			leased.id: leased,
		},
	}
	pool.accounts.Store(accountID, ap)
	pool.accounts.Store("invalid-key", "invalid-value")

	pool.Close()

	// isClosed reports whether c.closedCh can be received from without
	// blocking.
	isClosed := func(c *openAIWSConn) bool {
		select {
		case <-c.closedCh:
			return true
		default:
			return false
		}
	}

	if !isClosed(idle) {
		t.Fatal("空闲连接应在 Close 时被关闭")
	}
	if isClosed(leased) {
		t.Fatal("已租赁连接不应在 Close 时被关闭")
	}

	// Return the lease, then call Close once more.
	leased.release()
	pool.Close()
}
// TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit checks that a
// ping sweep over 25 idle connections reaches, but never exceeds, 10
// concurrent pings.
func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
	cfg := &config.Config{}
	pool := newOpenAIWSConnPool(cfg)
	accountID := int64(505)
	ap := pool.getOrCreateAccountPool(accountID)

	var current atomic.Int32
	var maxConcurrent atomic.Int32
	release := make(chan struct{})
	// Unblock the pings exactly once, including when an assertion aborts the
	// test early, so the blocked pings and the sweep goroutine do not linger
	// after a failure.
	var releaseOnce sync.Once
	releasePings := func() { releaseOnce.Do(func() { close(release) }) }
	t.Cleanup(releasePings)

	// Every connection blocks in Ping until release is closed and records the
	// number of concurrent pings in the shared counters.
	for i := 0; i < 25; i++ {
		conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
			current:       &current,
			maxConcurrent: &maxConcurrent,
			release:       release,
		}, nil)
		ap.mu.Lock()
		ap.conns[conn.id] = conn
		ap.mu.Unlock()
	}

	done := make(chan struct{})
	go func() {
		pool.runBackgroundPingSweep()
		close(done)
	}()

	// Wait until the sweep has 10 pings in flight.
	require.Eventually(t, func() bool {
		return maxConcurrent.Load() >= 10
	}, time.Second, 10*time.Millisecond)

	releasePings()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("runBackgroundPingSweep 未在释放后完成")
	}

	// The peak must never have gone above 10.
	require.LessOrEqual(t, maxConcurrent.Load(), int32(10))
}
// TestOpenAIWSConnLease_BasicGetterBranches checks that the lease getters
// return zero values on a nil lease and the stored values on a populated one,
// and that MarkPrewarmed flips IsPrewarmed.
func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
	// Nil lease: every getter returns its zero value and the mutators are
	// no-ops.
	var missing *openAIWSConnLease
	require.Equal(t, "", missing.ConnID())
	require.Equal(t, time.Duration(0), missing.QueueWaitDuration())
	require.Equal(t, time.Duration(0), missing.ConnPickDuration())
	require.False(t, missing.Reused())
	require.Equal(t, "", missing.HandshakeHeader("x-test"))
	require.False(t, missing.IsPrewarmed())
	missing.MarkPrewarmed()
	missing.Release()

	// Populated lease: the getters return the stored fields.
	handshake := http.Header{"X-Test": []string{"ok"}}
	populated := &openAIWSConnLease{
		conn:      newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, handshake),
		queueWait: 3 * time.Millisecond,
		connPick:  4 * time.Millisecond,
		reused:    true,
	}
	require.Equal(t, "getter_conn", populated.ConnID())
	require.Equal(t, 3*time.Millisecond, populated.QueueWaitDuration())
	require.Equal(t, 4*time.Millisecond, populated.ConnPickDuration())
	require.True(t, populated.Reused())
	require.Equal(t, "ok", populated.HandshakeHeader("x-test"))

	require.False(t, populated.IsPrewarmed())
	populated.MarkPrewarmed()
	require.True(t, populated.IsPrewarmed())
	populated.Release()
}
// TestOpenAIWSConnPool_UtilityBranches covers assorted small helpers of
// openAIWSConnPool: metrics snapshots, nil-pool defaults, prewarm suppression,
// connection-pick metrics, account pool lookup and the health-check condition.
func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
// A nil pool yields empty metrics snapshots.
var nilPool *openAIWSConnPool
require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics())
require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics())
// Counters stored on the pool show up in the snapshot.
pool := &openAIWSConnPool{cfg: &config.Config{}}
pool.metrics.acquireTotal.Store(7)
pool.metrics.acquireReuseTotal.Store(3)
metrics := pool.SnapshotMetrics()
require.Equal(t, int64(7), metrics.AcquireTotal)
require.Equal(t, int64(3), metrics.AcquireReuseTotal)
// Path where the dialer does not provide transport metrics.
pool.clientDialer = &openAIWSFakeDialer{}
require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics())
// Setting a nil test dialer must still leave a non-nil dialer in place.
pool.setClientDialerForTest(nil)
require.NotNil(t, pool.clientDialer)
// Values the config accessors return on a nil pool.
require.Equal(t, 8, nilPool.maxConnsHardCap())
require.False(t, nilPool.dynamicMaxConnsEnabled())
require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil))
require.Equal(t, 0, nilPool.minIdlePerAccount())
require.Equal(t, 4, nilPool.maxIdlePerAccount())
require.Equal(t, 256, nilPool.queueLimitPerConn())
require.Equal(t, 0.7, nilPool.targetUtilization())
require.Equal(t, time.Duration(0), nilPool.prewarmCooldown())
require.Equal(t, 10*time.Second, nilPool.dialTimeout())
// Cover the shouldSuppressPrewarmLocked branches (the original note says
// three; four inputs are tried below).
now := time.Now()
apNilFail := &openAIWSAccountPool{prewarmFails: 1}
require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now))
// Two failures with a zero failure time: not suppressed, counter reset to 0.
apZeroTime := &openAIWSAccountPool{prewarmFails: 2}
require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now))
require.Equal(t, 0, apZeroTime.prewarmFails)
// Failure recorded more than openAIWSPrewarmFailureWindow ago: not suppressed.
apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)}
require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now))
// openAIWSPrewarmFailureSuppress failures recorded just now: suppressed.
apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now}
require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now))
// Guard branches of recordConnPickDuration: a nil pool is tolerated and a
// negative duration is still counted as one pick.
nilPool.recordConnPickDuration(10 * time.Millisecond)
pool.recordConnPickDuration(-10 * time.Millisecond)
require.Equal(t, int64(1), pool.metrics.connPickTotal.Load())
// Account pool read/write branches.
require.Nil(t, nilPool.getOrCreateAccountPool(1))
require.Nil(t, pool.getOrCreateAccountPool(0))
// A mistyped stored value does not stop getOrCreateAccountPool from
// returning a pool.
pool.accounts.Store(int64(7), "invalid")
ap := pool.getOrCreateAccountPool(7)
require.NotNil(t, ap)
_, ok := pool.getAccountPool(0)
require.False(t, ok)
_, ok = pool.getAccountPool(12345)
require.False(t, ok)
pool.accounts.Store(int64(8), "bad-type")
_, ok = pool.getAccountPool(8)
require.False(t, ok)
// Health-check condition: nil is never checked; a connection last used
// more than openAIWSConnHealthCheckIdle ago is.
require.False(t, pool.shouldHealthCheckConn(nil))
conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano())
require.True(t, pool.shouldHealthCheckConn(conn))
}
// TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches checks the lease
// and timestamp helpers on a nil connection, the acquire/release cycle on a
// live one, and that a closed connection can no longer be acquired.
func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
	// Nil receiver: every helper is a no-op or returns its zero value.
	var missing *openAIWSConn
	missing.touch()
	require.Equal(t, time.Time{}, missing.createdAt())
	require.Equal(t, time.Time{}, missing.lastUsedAt())
	require.Equal(t, time.Duration(0), missing.idleDuration(time.Now()))
	require.Equal(t, time.Duration(0), missing.age(time.Now()))
	require.False(t, missing.isLeased())
	require.False(t, missing.isPrewarmed())
	missing.markPrewarmed()

	// Live connection: acquire marks it leased, release clears that.
	live := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
	require.True(t, live.tryAcquire())
	require.True(t, live.isLeased())
	live.release()
	require.False(t, live.isLeased())

	// After close, neither tryAcquire nor acquire succeeds.
	live.close()
	require.False(t, live.tryAcquire())

	cancelled, cancel := context.WithCancel(context.Background())
	cancel()
	acquireErr := live.acquire(cancelled)
	require.Error(t, acquireErr)
}
// TestOpenAIWSConnLease_ReadWriteNilConnBranches checks that every read and
// write helper on a lease without a connection returns errOpenAIWSConnClosed.
func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
	empty := &openAIWSConnLease{}
	bg := context.Background()
	body := map[string]any{"k": "v"}

	require.ErrorIs(t, empty.WriteJSON(body, time.Second), errOpenAIWSConnClosed)
	require.ErrorIs(t, empty.WriteJSONContext(bg, body), errOpenAIWSConnClosed)

	_, readErr := empty.ReadMessage(10 * time.Millisecond)
	require.ErrorIs(t, readErr, errOpenAIWSConnClosed)

	_, readCtxErr := empty.ReadMessageContext(bg)
	require.ErrorIs(t, readCtxErr, errOpenAIWSConnClosed)

	_, readTimeoutErr := empty.ReadMessageWithContextTimeout(bg, 10*time.Millisecond)
	require.ErrorIs(t, readTimeoutErr, errOpenAIWSConnClosed)
}
func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
lease.Release()
lease.Release() // idempotent
require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
_, err := lease.ReadMessage(10 * time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
_, err = lease.ReadMessageContext(context.Background())
require.ErrorIs(t, err, errOpenAIWSConnClosed)
_, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
}
// TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction checks that calling
// MarkBroken on a lease that was already released leaves the connection in
// the account pool.
func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
	const accountID = int64(7)
	conn := newOpenAIWSConn("released_markbroken", accountID, &openAIWSFakeConn{}, nil)

	ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{conn.id: conn}}
	pool := &openAIWSConnPool{}
	pool.accounts.Store(accountID, ap)

	lease := &openAIWSConnLease{pool: pool, accountID: accountID, conn: conn}
	lease.Release()
	lease.MarkBroken()

	ap.mu.Lock()
	_, stillPooled := ap.conns[conn.id]
	ap.mu.Unlock()
	require.True(t, stillPooled, "released lease should not evict active pool connection")
}
// TestOpenAIWSConn_AdditionalGuardBranches covers guard branches of
// openAIWSConn that other tests do not reach: nil receiver, busy connection
// with a cancelled context, closed connection, connection without an
// underlying websocket, zero timeouts, zero timestamps, and a few clone/close
// helpers.
func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
// Nil receiver: acquire paths fail, the remaining helpers are no-ops.
var nilConn *openAIWSConn
require.False(t, nilConn.tryAcquire())
require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed)
nilConn.release()
nilConn.close()
require.Equal(t, "", nilConn.handshakeHeader("x-test"))
// Busy connection: acquire with an already-cancelled context reports
// context.Canceled.
connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
require.True(t, connBusy.tryAcquire())
ctx, cancel := context.WithCancel(context.Background())
cancel()
require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled)
connBusy.release()
// Closed connection: write, read and ping all report errOpenAIWSConnClosed.
connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
connClosed.close()
require.ErrorIs(
t,
connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second),
errOpenAIWSConnClosed,
)
_, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
// Connection built with a nil underlying websocket and nil headers.
connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil)
require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed)
_, err = connNoWS.readMessage(context.Background())
require.ErrorIs(t, err, errOpenAIWSConnClosed)
require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
require.Equal(t, "", connNoWS.handshakeHeader("x-test"))
// Healthy connection: a nil context and zero timeouts are accepted.
connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil))
_, err = connOK.readMessageWithContextTimeout(context.Background(), 0)
require.NoError(t, err)
require.NoError(t, connOK.pingWithTimeout(0))
// Zeroed timestamps map to zero times and zero durations.
connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
connZero.createdAtNano.Store(0)
connZero.lastUsedNano.Store(0)
require.True(t, connZero.createdAt().IsZero())
require.True(t, connZero.lastUsedAt().IsZero())
require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now()))
require.Equal(t, time.Duration(0), connZero.age(time.Now()))
// Clone helpers: a nil request clones to nil; an empty header value slice
// keeps its key but clones to a nil slice.
require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
copied := cloneHeader(http.Header{
"X-Empty": []string{},
"X-Test": []string{"v1"},
})
require.Contains(t, copied, "X-Empty")
require.Nil(t, copied["X-Empty"])
require.Equal(t, "v1", copied.Get("X-Test"))
// closeOpenAIWSConns must tolerate nil entries in the slice.
closeOpenAIWSConns([]*openAIWSConn{nil, connOK})
}
// TestOpenAIWSConnLease_MarkBrokenEvictsConn checks that MarkBroken on an
// unreleased lease removes the connection from the account pool and leaves it
// impossible to acquire.
func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
	pool := newOpenAIWSConnPool(&config.Config{})
	accountID := int64(5001)
	broken := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil)

	ap := pool.getOrCreateAccountPool(accountID)
	ap.mu.Lock()
	ap.conns[broken.id] = broken
	ap.mu.Unlock()

	lease := &openAIWSConnLease{pool: pool, accountID: accountID, conn: broken}
	lease.MarkBroken()

	ap.mu.Lock()
	_, stillPooled := ap.conns[broken.id]
	ap.mu.Unlock()

	require.False(t, stillPooled)
	require.False(t, broken.tryAcquire(), "被标记为 broken 的连接应被关闭")
}
// TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches covers
// targetConnCountLocked (nil pool, zero cap, min-idle clamped by the cap,
// growth under waiters) and two prewarmConns paths (missing account pool and
// dial failure). The config is mutated after the pool is built, so the order
// of the statements matters.
func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
pool := newOpenAIWSConnPool(cfg)
// A nil account pool and a zero cap both give a target of 0.
require.Equal(t, 0, pool.targetConnCountLocked(nil, 1))
ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
require.Equal(t, 0, pool.targetConnCountLocked(ap, 0))
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3
require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断")
// Cover the branch where waiters > 0 and the target must be at least
// len(conns)+1.
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
require.True(t, busy.tryAcquire())
busy.waiters.Store(1)
ap.conns[busy.id] = busy
target := pool.targetConnCountLocked(ap, 4)
require.GreaterOrEqual(t, target, len(ap.conns)+1)
// prewarm: when the account pool is missing, the dialed connection should
// be closed and the call should return early.
req := openAIWSAcquireRequest{
Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
WSURL: "wss://example.com/v1/responses",
}
pool.prewarmConns(999, req, 1)
// prewarm: dial-failure branch (prewarmFails is incremented).
accountID := int64(1000)
failPool := newOpenAIWSConnPool(cfg)
failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
apFail := failPool.getOrCreateAccountPool(accountID)
// NOTE(review): creating is presumably the in-flight dial count that
// prewarmConns decrements — confirm against the implementation.
apFail.mu.Lock()
apFail.creating = 1
apFail.mu.Unlock()
req.Account.ID = accountID
failPool.prewarmConns(accountID, req, 1)
apFail.mu.Lock()
require.GreaterOrEqual(t, apFail.prewarmFails, 1)
apFail.mu.Unlock()
}
// TestOpenAIWSConnPool_Acquire_ErrorBranches checks the error returned by
// Acquire for a nil pool, a blank websocket URL, a full pool holding only a
// nil connection, and a connection whose waiter queue is at its limit.
func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
	bg := context.Background()

	// Nil pool.
	var missing *openAIWSConnPool
	_, nilPoolErr := missing.Acquire(bg, openAIWSAcquireRequest{})
	require.Error(t, nilPoolErr)

	// Whitespace-only URL.
	plainPool := newOpenAIWSConnPool(&config.Config{})
	_, blankURLErr := plainPool.Acquire(bg, openAIWSAcquireRequest{
		Account: &Account{ID: 1},
		WSURL:   " ",
	})
	require.Error(t, blankURLErr)
	require.Contains(t, blankURLErr.Error(), "ws url is empty")

	// target == nil branch: the pool is full and holds only a nil connection.
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
	fullPool := newOpenAIWSConnPool(cfg)

	nilConnAccount := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	nilConnAP := fullPool.getOrCreateAccountPool(nilConnAccount.ID)
	nilConnAP.mu.Lock()
	nilConnAP.conns["nil"] = nil
	nilConnAP.lastCleanupAt = time.Now()
	nilConnAP.mu.Unlock()

	_, nilTargetErr := fullPool.Acquire(bg, openAIWSAcquireRequest{
		Account: nilConnAccount,
		WSURL:   "wss://example.com/v1/responses",
	})
	require.ErrorIs(t, nilTargetErr, errOpenAIWSConnClosed)

	// Queue-full branch: the only connection is leased and its waiters are
	// already at the limit.
	queuedAccount := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	queuedAP := fullPool.getOrCreateAccountPool(queuedAccount.ID)
	saturated := newOpenAIWSConn("queue_full", queuedAccount.ID, &openAIWSFakeConn{}, nil)
	require.True(t, saturated.tryAcquire())
	saturated.waiters.Store(1)
	queuedAP.mu.Lock()
	queuedAP.conns[saturated.id] = saturated
	queuedAP.lastCleanupAt = time.Now()
	queuedAP.mu.Unlock()

	_, queueFullErr := fullPool.Acquire(bg, openAIWSAcquireRequest{
		Account: queuedAccount,
		WSURL:   "wss://example.com/v1/responses",
	})
	require.ErrorIs(t, queueFullErr, errOpenAIWSConnQueueFull)
}
// openAIWSFakeDialer is a stateless test dialer that hands out a fresh
// openAIWSFakeConn on every call.
type openAIWSFakeDialer struct{}

// Dial ignores all of its arguments and returns a new openAIWSFakeConn with
// status code 0, no response headers and no error.
func (d *openAIWSFakeDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_, _, _, _ = ctx, wsURL, headers, proxyURL
	conn := &openAIWSFakeConn{}
	return conn, 0, nil, nil
}
// openAIWSCountingDialer is a test dialer that counts how many times its Dial
// method is called.
type openAIWSCountingDialer struct {
// mu guards dialCount.
mu sync.Mutex
// dialCount is incremented by each Dial call and read through DialCount.
dialCount int
}
// openAIWSAlwaysFailDialer is a test dialer whose Dial method always fails,
// while counting how many times it is called.
type openAIWSAlwaysFailDialer struct {
// mu guards dialCount.
mu sync.Mutex
// dialCount is incremented by each Dial call and read through DialCount.
dialCount int
}
// openAIWSPingBlockingConn is a test connection whose Ping blocks until the
// release channel can be received from or the context ends. While blocked it
// tracks the number of pings in flight (current) and the highest such number
// seen so far (maxConcurrent).
type openAIWSPingBlockingConn struct {
	current       *atomic.Int32
	maxConcurrent *atomic.Int32
	release       <-chan struct{}
}

// WriteJSON discards its arguments and always succeeds.
func (c *openAIWSPingBlockingConn) WriteJSON(_ context.Context, _ any) error {
	return nil
}

// ReadMessage returns a canned "response.completed" event.
func (c *openAIWSPingBlockingConn) ReadMessage(_ context.Context) ([]byte, error) {
	const completed = `{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`
	return []byte(completed), nil
}

// Ping returns nil at once when either counter is missing. Otherwise it
// registers itself as in flight, raises the recorded peak if needed, and then
// waits for release (returning nil) or for ctx to end (returning ctx.Err()).
func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
	if c.current == nil || c.maxConcurrent == nil {
		return nil
	}

	inFlight := c.current.Add(1)
	defer c.current.Add(-1)

	// Raise the peak to inFlight unless it is already that high; retry when
	// another ping changed the peak between the load and the swap.
	for {
		peak := c.maxConcurrent.Load()
		if inFlight <= peak {
			break
		}
		if c.maxConcurrent.CompareAndSwap(peak, inFlight) {
			break
		}
	}

	select {
	case <-c.release:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Close does nothing and always succeeds.
func (c *openAIWSPingBlockingConn) Close() error {
	return nil
}
// Dial ignores its arguments, records the call under the mutex and returns a
// new openAIWSFakeConn with status code 0, no response headers and no error.
func (d *openAIWSCountingDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_, _, _, _ = ctx, wsURL, headers, proxyURL

	// Count the call.
	func() {
		d.mu.Lock()
		defer d.mu.Unlock()
		d.dialCount++
	}()

	conn := &openAIWSFakeConn{}
	return conn, 0, nil, nil
}
// DialCount returns the number of Dial calls recorded so far, read under the
// mutex.
func (d *openAIWSCountingDialer) DialCount() int {
	d.mu.Lock()
	calls := d.dialCount
	d.mu.Unlock()
	return calls
}
// Dial ignores its arguments, records the call under the mutex and always
// fails: it returns a nil connection, status code 503, no response headers
// and a "dial failed" error.
func (d *openAIWSAlwaysFailDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_, _, _, _ = ctx, wsURL, headers, proxyURL

	// Count the call.
	func() {
		d.mu.Lock()
		defer d.mu.Unlock()
		d.dialCount++
	}()

	return nil, http.StatusServiceUnavailable, nil, errors.New("dial failed")
}
// DialCount returns the number of Dial calls recorded so far, read under the
// mutex.
func (d *openAIWSAlwaysFailDialer) DialCount() int {
	d.mu.Lock()
	calls := d.dialCount
	d.mu.Unlock()
	return calls
}
// openAIWSFakeConn is an in-memory test connection. Writes append a marker to
// payload, reads return a canned "response.completed" event, and both fail
// with "closed" once Close has been called.
type openAIWSFakeConn struct {
	mu      sync.Mutex
	closed  bool
	payload [][]byte
}

// WriteJSON records one "ok" marker per call; the context and the value
// itself are ignored. It fails with "closed" after Close.
func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
	_, _ = ctx, value

	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.payload = append(c.payload, []byte("ok"))
		return nil
	}
	return errors.New("closed")
}

// ReadMessage returns the canned completion event, or "closed" after Close.
// The context is ignored.
func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
	_ = ctx

	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
	}
	return nil, errors.New("closed")
}

// Ping always succeeds, whether or not the connection has been closed.
func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
	_ = ctx
	return nil
}

// Close marks the connection closed. It never fails and may be called more
// than once.
func (c *openAIWSFakeConn) Close() error {
	c.mu.Lock()
	c.closed = true
	c.mu.Unlock()
	return nil
}
// openAIWSBlockingConn is a test connection whose reads take readDelay to
// complete unless the context ends first.
type openAIWSBlockingConn struct {
	readDelay time.Duration
}

// WriteJSON discards its arguments and always succeeds.
func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error {
	_, _ = ctx, value
	return nil
}

// ReadMessage waits for readDelay (10ms when readDelay is not positive) and
// then returns a canned "response.completed" event. If ctx ends first it
// returns ctx.Err() instead.
func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
	wait := 10 * time.Millisecond
	if c.readDelay > 0 {
		wait = c.readDelay
	}

	expiry := time.NewTimer(wait)
	defer expiry.Stop()

	select {
	case <-expiry.C:
		return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// Ping always succeeds; the context is ignored.
func (c *openAIWSBlockingConn) Ping(ctx context.Context) error {
	_ = ctx
	return nil
}

// Close does nothing and always succeeds.
func (c *openAIWSBlockingConn) Close() error {
	return nil
}
// openAIWSWriteBlockingConn is a test connection whose writes never complete
// on their own: WriteJSON only returns once its context is done.
type openAIWSWriteBlockingConn struct{}

// WriteJSON blocks until ctx is done and then returns ctx.Err(). The value to
// write is ignored.
func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
	done := ctx.Done()
	<-done
	return ctx.Err()
}

// ReadMessage immediately returns a canned response.completed event.
func (c *openAIWSWriteBlockingConn) ReadMessage(_ context.Context) ([]byte, error) {
	const event = `{"type":"response.completed","response":{"id":"resp_write_block"}}`
	return []byte(event), nil
}

// Ping always succeeds.
func (c *openAIWSWriteBlockingConn) Ping(_ context.Context) error { return nil }

// Close is a no-op and always returns nil.
func (c *openAIWSWriteBlockingConn) Close() error { return nil }
// openAIWSPingFailConn is a test connection on which reads and writes succeed
// but every Ping fails.
type openAIWSPingFailConn struct{}

// WriteJSON discards its arguments and always succeeds.
func (c *openAIWSPingFailConn) WriteJSON(_ context.Context, _ any) error { return nil }

// ReadMessage immediately returns a canned response.completed event.
func (c *openAIWSPingFailConn) ReadMessage(_ context.Context) ([]byte, error) {
	const event = `{"type":"response.completed","response":{"id":"resp_ping_fail"}}`
	return []byte(event), nil
}

// Ping always fails with a "ping failed" error.
func (c *openAIWSPingFailConn) Ping(_ context.Context) error {
	pingErr := errors.New("ping failed")
	return pingErr
}

// Close is a no-op and always returns nil.
func (c *openAIWSPingFailConn) Close() error { return nil }
// openAIWSContextProbeConn is a test connection that remembers the context
// passed to the most recent WriteJSON call so a test can inspect it.
type openAIWSContextProbeConn struct {
	lastWriteCtx context.Context
}

// WriteJSON stores ctx in lastWriteCtx, ignores the value and always succeeds.
// The field is written without any locking.
func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
	c.lastWriteCtx = ctx
	return nil
}

// ReadMessage immediately returns a canned response.completed event.
func (c *openAIWSContextProbeConn) ReadMessage(_ context.Context) ([]byte, error) {
	const event = `{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`
	return []byte(event), nil
}

// Ping always succeeds.
func (c *openAIWSContextProbeConn) Ping(_ context.Context) error { return nil }

// Close is a no-op and always returns nil.
func (c *openAIWSContextProbeConn) Close() error { return nil }
// openAIWSNilConnDialer is a test dialer that reports a successful dial but
// supplies no connection.
type openAIWSNilConnDialer struct{}

// Dial ignores every argument and returns a nil connection together with
// status code 200, nil headers and a nil error.
func (d *openAIWSNilConnDialer) Dial(
	_ context.Context,
	_ string,
	_ http.Header,
	_ string,
) (openAIWSClientConn, int, http.Header, error) {
	return nil, 200, nil, nil
}
// TestOpenAIWSConnPool_DialConnNilConnection installs a dialer that returns a
// nil connection with no error and checks that Acquire surfaces an error whose
// message contains "nil connection".
func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
	poolCfg := &config.Config{}
	poolCfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	poolCfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1

	connPool := newOpenAIWSConnPool(poolCfg)
	connPool.setClientDialerForTest(&openAIWSNilConnDialer{})

	acquireReq := openAIWSAcquireRequest{
		Account: &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
		WSURL:   "wss://example.com/v1/responses",
	}
	_, acquireErr := connPool.Acquire(context.Background(), acquireReq)

	require.Error(t, acquireErr)
	require.Contains(t, acquireErr.Error(), "nil connection")
}
func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
cfg := &config.Config{}
pool := newOpenAIWSConnPool(cfg)
dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
require.True(t, ok)
_, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
require.NoError(t, err)
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
require.NoError(t, err)
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
require.NoError(t, err)
snapshot := pool.SnapshotTransportMetrics()
require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
}
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled
// runs Forward with WS v2 enabled against a base URL whose server answers 404
// to everything. It asserts that Forward fails with a nil result and that the
// HTTP upstream stub never sees a request, i.e. there is no HTTP fallback.
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Every request to the account's base URL gets a 404.
	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          1,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "WS 模式下失败时不应回退 HTTP")
}
// TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled marks the
// inbound request as HTTP transport while WS v2 is enabled and asserts that
// Forward succeeds over the HTTP upstream stub: the result is not in WS mode,
// previous_response_id is absent from the forwarded body, and the gin context
// records the HTTP/SSE transport decision with reason "client_protocol_http".
func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Every request to the account's base URL gets a 404.
	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")
	SetOpenAIClientTransport(ginCtx, OpenAIClientTransportHTTP)

	const stubBody = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          101,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.NoError(t, fwdErr)
	require.NotNil(t, res)
	require.False(t, res.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
	require.NotNil(t, httpStub.lastReq, "HTTP 入站应命中 HTTP 上游")
	prevIDForwarded := gjson.GetBytes(httpStub.lastBody, "previous_response_id").Exists()
	require.False(t, prevIDForwarded, "HTTP 路径应沿用原逻辑移除 previous_response_id")

	transportDecision, _ := ginCtx.Get("openai_ws_transport_decision")
	transportReason, _ := ginCtx.Get("openai_ws_transport_reason")
	require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), transportDecision)
	require.Equal(t, "client_protocol_http", transportReason)
}
// TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled
// disables the OpenAI WS gateway globally and asserts that Forward succeeds and
// that the body captured by the HTTP upstream stub no longer contains
// previous_response_id.
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Every request to the account's base URL gets a 404.
	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = false
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          1,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.NoError(t, fwdErr)
	require.NotNil(t, res)
	require.False(t, gjson.GetBytes(httpStub.lastBody, "previous_response_id").Exists())
}
// TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP points the account
// at a server that answers every request with 426 Upgrade Required. Despite
// the "FallbackHTTP" in its name, the assertions require that Forward fails
// with an "upgrade_required" error, that no HTTP upstream request is made, and
// that the downstream response carries status 426 with "426" in its body.
func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	upgradeRequired := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUpgradeRequired)
		_, _ = w.Write([]byte(`upgrade required`))
	})
	upgradeServer := httptest.NewServer(upgradeRequired)
	defer upgradeServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          12,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upgradeServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Contains(t, fwdErr.Error(), "upgrade_required")
	require.Nil(t, httpStub.lastReq, "WS 模式下不应再回退 HTTP")
	require.Equal(t, http.StatusUpgradeRequired, recorder.Code)
	require.Contains(t, recorder.Body.String(), "426")
}
// TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS marks the account
// as fallback-cooling before calling Forward against a 404 server. It asserts
// that Forward still fails without touching the HTTP upstream stub and that the
// "openai_ws_fallback_cooling" key is not set on the gin context.
func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Every request to the account's base URL gets a 404.
	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 30

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          21,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Put the account into the cooling state before the request is forwarded.
	gateway.markOpenAIWSFallbackCooling(acct.ID, "upgrade_required")

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "WS 模式下不应再回退 HTTP")

	_, coolingFlagSet := ginCtx.Get("openai_ws_fallback_cooling")
	require.False(t, coolingFlagSet, "已移除 fallback cooling 快捷回退路径")
}
// TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled enables only
// the v1 websocket flag (v2 off) and asserts that Forward fails with an error
// mentioning "ws v1", writes a 400 response whose body mentions "WSv1", and
// never issues an HTTP upstream request.
func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsockets = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          31,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": "https://api.openai.com/v1/responses",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Contains(t, fwdErr.Error(), "ws v1")
	require.Equal(t, http.StatusBadRequest, recorder.Code)
	require.Contains(t, recorder.Body.String(), "WSv1")
	require.Nil(t, httpStub.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
}
// TestNewOpenAIGatewayService_InitializesOpenAIWSResolver builds a service via
// the constructor with only the config supplied (every other dependency nil)
// and asserts that resolving a nil account yields the HTTP/SSE transport with
// reason "account_missing".
func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
	conf := &config.Config{}

	// The config is the sixth of fourteen constructor arguments.
	gateway := NewOpenAIGatewayService(
		nil, nil, nil, nil, nil,
		conf,
		nil, nil, nil, nil, nil, nil, nil, nil,
	)

	resolved := gateway.getOpenAIWSProtocolResolver().Resolve(nil)
	require.Equal(t, OpenAIUpstreamTransportHTTPSSE, resolved.Transport)
	require.Equal(t, "account_missing", resolved.Reason)
}
// TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError
// writes a downstream response before calling Forward against a server that
// answers 426. It asserts that Forward fails with an error mentioning
// "ws fallback" and that the HTTP upstream stub never sees a request.
func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	upgradeRequired := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUpgradeRequired)
		_, _ = w.Write([]byte(`upgrade required`))
	})
	upgradeServer := httptest.NewServer(upgradeRequired)
	defer upgradeServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")
	// Something has already been written downstream before Forward runs.
	ginCtx.String(http.StatusAccepted, "already-written")

	const stubBody = `{"usage":{"input_tokens":1,"output_tokens":1}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
	}

	acct := &Account{
		ID:          41,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upgradeServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Contains(t, fwdErr.Error(), "ws fallback")
	require.Nil(t, httpStub.lastReq, "已写下游响应时,不应再回退 HTTP")
}
// TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP runs a
// streaming request against a websocket server that emits only
// response.created and then closes with an internal-server-error close frame.
// Despite the "FallbackHTTP" in its name, the assertions require that Forward
// fails, that no HTTP upstream request is made, and that nothing is written to
// the downstream recorder.
func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	earlyClose := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wsConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = wsConn.Close()
		}()

		var incoming map[string]any
		if readErr := wsConn.ReadJSON(&incoming); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		// Send only response.created (a non-token event) and then close right
		// away, simulating the production case where the upstream drops the
		// connection early with an internal error.
		createdEvent := map[string]any{
			"type": "response.created",
			"response": map[string]any{
				"id":    "resp_ws_created_only",
				"model": "gpt-5.3-codex",
			},
		}
		if writeErr := wsConn.WriteJSON(createdEvent); writeErr != nil {
			t.Errorf("write response.created failed: %v", writeErr)
			return
		}

		closeFrame := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
		_ = wsConn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	})
	wsBackend := httptest.NewServer(earlyClose)
	defer wsBackend.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	sseBody := "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
		"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
		"data: [DONE]\n\n"
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
			Body:       io.NopCloser(strings.NewReader(sseBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}

	acct := &Account{
		ID:          88,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsBackend.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "WS 早期断连后不应再回退 HTTP")
	require.Empty(t, recorder.Body.String(), "未产出 token 前上游断连时不应写入下游半截流")
}
// TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP uses a
// websocket server that reads one request and then closes with an
// internal-server-error close frame on every connection. Despite the
// "FallbackHTTP" in its name, the assertions require that Forward fails, that
// no HTTP upstream request is made, and that the server saw exactly
// openAIWSReconnectRetryLimit+1 connection attempts.
func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var connAttempts atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	alwaysClose := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		connAttempts.Add(1)

		wsConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = wsConn.Close()
		}()

		var incoming map[string]any
		if readErr := wsConn.ReadJSON(&incoming); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		closeFrame := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
		_ = wsConn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	})
	wsBackend := httptest.NewServer(alwaysClose)
	defer wsBackend.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	sseBody := "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
		"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
		"data: [DONE]\n\n"
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
			Body:       io.NopCloser(strings.NewReader(sseBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}

	acct := &Account{
		ID:          89,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsBackend.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "WS 重连耗尽后不应再回退 HTTP")
	require.Equal(t, int32(openAIWSReconnectRetryLimit+1), connAttempts.Load())
}
// TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP uses a
// websocket server that reads one request and then closes with a
// policy-violation close frame. Despite the "FallbackHTTP" in its name, the
// assertions require that Forward fails, that no HTTP upstream request is
// made, and that only a single websocket attempt happened (no retry).
func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var connAttempts atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	policyClose := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		connAttempts.Add(1)

		wsConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = wsConn.Close()
		}()

		var incoming map[string]any
		if readErr := wsConn.ReadJSON(&incoming); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		closeFrame := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")
		_ = wsConn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	})
	wsBackend := httptest.NewServer(policyClose)
	defer wsBackend.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
	conf.Gateway.OpenAIWS.RetryBackoffInitialMS = 1
	conf.Gateway.OpenAIWS.RetryBackoffMaxMS = 2
	conf.Gateway.OpenAIWS.RetryJitterRatio = 0

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}

	acct := &Account{
		ID:          8901,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsBackend.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "策略违规关闭后不应回退 HTTP")
	require.Equal(t, int32(1), connAttempts.Load(), "策略违规不应进行 WS 重试")
}
// TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP
// uses a websocket server that answers every request with a
// websocket_connection_limit_reached error event. Despite the "FallbackHTTP"
// in its name, the assertions require that Forward fails, that no HTTP
// upstream request is made, and that the server saw exactly
// openAIWSReconnectRetryLimit+1 connection attempts.
func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var connAttempts atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	limitReached := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		connAttempts.Add(1)

		wsConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = wsConn.Close()
		}()

		var incoming map[string]any
		if readErr := wsConn.ReadJSON(&incoming); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		errorEvent := map[string]any{
			"type": "error",
			"error": map[string]any{
				"code":    "websocket_connection_limit_reached",
				"type":    "server_error",
				"message": "websocket connection limit reached",
			},
		}
		_ = wsConn.WriteJSON(errorEvent)
	})
	wsBackend := httptest.NewServer(limitReached)
	defer wsBackend.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}

	acct := &Account{
		ID:          90,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsBackend.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.Error(t, fwdErr)
	require.Nil(t, res)
	require.Nil(t, httpStub.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP")
	require.Equal(t, int32(openAIWSReconnectRetryLimit+1), connAttempts.Load())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID
// uses a websocket server that rejects the first connection's request with
// previous_response_not_found and completes any later one. It asserts that
// Forward succeeds over websocket on the second attempt without touching the
// HTTP upstream stub, that the downstream response is a 200 carrying the
// completed response id, and that the first captured request carried
// previous_response_id while the second did not.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		connAttempts atomic.Int32
		capturedMu   sync.Mutex
		captured     [][]byte
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	recoverOnRetry := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		attemptNo := connAttempts.Add(1)

		wsConn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() {
			_ = wsConn.Close()
		}()

		var incoming map[string]any
		if readErr := wsConn.ReadJSON(&incoming); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		// Keep a re-encoded copy of every request for the assertions below.
		encoded, _ := json.Marshal(incoming)
		capturedMu.Lock()
		captured = append(captured, encoded)
		capturedMu.Unlock()

		// Only the very first connection is rejected.
		if attemptNo == 1 {
			_ = wsConn.WriteJSON(map[string]any{
				"type": "error",
				"error": map[string]any{
					"code":    "previous_response_not_found",
					"type":    "invalid_request_error",
					"message": "previous response not found",
				},
			})
			return
		}

		completedEvent := map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    "resp_ws_prev_recover_ok",
				"model": "gpt-5.3-codex",
				"usage": map[string]any{
					"input_tokens":  1,
					"output_tokens": 1,
					"input_tokens_details": map[string]any{
						"cached_tokens": 0,
					},
				},
			},
		}
		_ = wsConn.WriteJSON(completedEvent)
	})
	wsBackend := httptest.NewServer(recoverOnRetry)
	defer wsBackend.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const stubBody = `{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(stubBody)),
		},
	}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}

	acct := &Account{
		ID:          91,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsBackend.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	const reqJSON = `{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(reqJSON))

	require.NoError(t, fwdErr)
	require.NotNil(t, res)
	require.Equal(t, "resp_ws_prev_recover_ok", res.RequestID)
	require.Nil(t, httpStub.lastReq, "previous_response_not_found 不应回退 HTTP")
	require.Equal(t, int32(2), connAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试")
	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(recorder.Body.String(), "id").String())

	// Snapshot the captured requests under the lock before inspecting them.
	capturedMu.Lock()
	seen := append([][]byte(nil), captured...)
	capturedMu.Unlock()

	require.Len(t, seen, 2)
	require.True(t, gjson.GetBytes(seen[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
	require.False(t, gjson.GetBytes(seen[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput
// checks that when the upstream WebSocket answers previous_response_not_found
// for a request whose input contains a function_call_output item, Forward
// returns an error after exactly one WS attempt, writes a 400 response, and
// leaves the HTTP upstream stub untouched.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) {
gin.SetMode(gin.TestMode)
// wsAttempts counts requests reaching the fake server; wsRequestPayloads
// (guarded by wsRequestMu) keeps the first JSON frame of each connection.
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
// Fake upstream: upgrade, read one JSON request, record it, then always
// reply with a previous_response_not_found error event.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "previous_response_not_found",
"type": "invalid_request_error",
"message": "previous response not found",
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
// HTTP upstream stub with a canned 200 response; the test later asserts its
// lastReq stays nil (presumably lastReq records the last HTTP request sent —
// confirm against httpUpstreamRecorder).
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
// Enable the WS v2 path globally and for both auth types; the allowlist is
// disabled and insecure HTTP allowed so the httptest server URL is accepted.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
// API-key account pointed at the fake server with the account-level WS v2
// switch turned on.
account := &Account{
ID: 92,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// The request carries previous_response_id together with a
// function_call_output input item.
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP")
require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复")
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found")
// Snapshot the recorded payloads under the lock before inspecting them.
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 1)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID
// checks that when the client request has no previous_response_id and the
// upstream WebSocket still answers previous_response_not_found, Forward
// returns an error after exactly one WS attempt, writes a 400 response, and
// leaves the HTTP upstream stub untouched.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
// wsAttempts counts requests reaching the fake server; wsRequestPayloads
// (guarded by wsRequestMu) keeps the first JSON frame of each connection.
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
// Fake upstream: upgrade, read one JSON request, record it, then always
// reply with a previous_response_not_found error event.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "previous_response_not_found",
"type": "invalid_request_error",
"message": "previous response not found",
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
// HTTP upstream stub with a canned 200 response; the test later asserts its
// lastReq stays nil (presumably lastReq records the last HTTP request sent —
// confirm against httpUpstreamRecorder).
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
// Enable the WS v2 path globally and for both auth types; the allowlist is
// disabled and insecure HTTP allowed so the httptest server URL is accepted.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
// API-key account pointed at the fake server with the account-level WS v2
// switch turned on.
account := &Account{
ID: 93,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// The request deliberately omits previous_response_id.
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试")
require.Equal(t, http.StatusBadRequest, rec.Code)
// Snapshot the recorded payloads under the lock before inspecting them.
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 1)
require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce
// checks that when the upstream WebSocket keeps answering
// previous_response_not_found, Forward performs exactly one recovery retry
// (two WS attempts in total), sends the retry without previous_response_id,
// then returns an error with a 400 response and no HTTP fallback.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
// wsAttempts counts requests reaching the fake server; wsRequestPayloads
// (guarded by wsRequestMu) keeps the first JSON frame of each connection.
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
// Fake upstream: upgrade, read one JSON request, record it, then always
// reply with a previous_response_not_found error event.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "previous_response_not_found",
"type": "invalid_request_error",
"message": "previous response not found",
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
// HTTP upstream stub with a canned 200 response; the test later asserts its
// lastReq stays nil (presumably lastReq records the last HTTP request sent —
// confirm against httpUpstreamRecorder).
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
// Enable the WS v2 path globally and for both auth types; the allowlist is
// disabled and insecure HTTP allowed so the httptest server URL is accepted.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
// API-key account pointed at the fake server with the account-level WS v2
// switch turned on.
account := &Account{
ID: 94,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// Plain text input with a previous_response_id.
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试")
require.Equal(t, http.StatusBadRequest, rec.Code)
// Snapshot the recorded payloads under the lock before inspecting them.
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
}
package service
import "github.com/Wei-Shaw/sub2api/internal/config"
// OpenAIUpstreamTransport identifies the transport protocol used to reach the
// OpenAI upstream.
type OpenAIUpstreamTransport string
const (
// OpenAIUpstreamTransportAny is the zero value (empty string); presumably
// it means "no specific transport" — confirm against callers.
OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
// OpenAIUpstreamTransportHTTPSSE is the transport returned by every HTTP
// fallback decision (see openAIWSHTTPDecision).
OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
// OpenAIUpstreamTransportResponsesWebsocket is chosen when only the v1
// Responses WebSocket feature flag is enabled.
OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
// OpenAIUpstreamTransportResponsesWebsocketV2 is chosen when the v2
// Responses WebSocket feature flag is enabled; it wins over v1.
OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
)
// OpenAIWSProtocolDecision is the outcome of a protocol decision: the chosen
// transport plus a short machine-readable reason code (for example
// "ws_v2_enabled" or "global_disabled").
type OpenAIWSProtocolDecision struct {
Transport OpenAIUpstreamTransport
Reason string
}
// OpenAIWSProtocolResolver decides which upstream protocol to use for an
// OpenAI account.
type OpenAIWSProtocolResolver interface {
// Resolve returns the transport decision for account.
Resolve(account *Account) OpenAIWSProtocolDecision
}
// defaultOpenAIWSProtocolResolver is the config-driven OpenAIWSProtocolResolver
// implementation; it reads its switches from cfg.Gateway.OpenAIWS.
type defaultOpenAIWSProtocolResolver struct {
cfg *config.Config
}
// NewOpenAIWSProtocolResolver builds the default protocol resolver backed by
// cfg. A nil cfg is accepted; Resolve then reports "config_missing".
func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
	resolver := new(defaultOpenAIWSProtocolResolver)
	resolver.cfg = cfg
	return resolver
}
// Resolve picks the upstream transport for account.
//
// Checks run in a fixed order and the first failing one yields an HTTP/SSE
// decision whose Reason names the gate: account presence, platform,
// account-level force-HTTP, config presence, global force-HTTP, global enable,
// and the per-auth-type switch. After that, either the mode router (when
// ModeRouterV2Enabled) or the legacy account boolean decides whether WS is
// allowed, and the v2 feature flag is preferred over v1.
func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
	// Account-level gates come before the config check, so they are reported
	// even when the resolver has no config.
	switch {
	case account == nil:
		return openAIWSHTTPDecision("account_missing")
	case !account.IsOpenAI():
		return openAIWSHTTPDecision("platform_not_openai")
	case account.IsOpenAIWSForceHTTPEnabled():
		return openAIWSHTTPDecision("account_force_http")
	case r == nil || r.cfg == nil:
		return openAIWSHTTPDecision("config_missing")
	}

	settings := r.cfg.Gateway.OpenAIWS
	switch {
	case settings.ForceHTTP:
		return openAIWSHTTPDecision("global_force_http")
	case !settings.Enabled:
		return openAIWSHTTPDecision("global_disabled")
	}

	// Each auth type has its own global switch; anything else is rejected.
	switch {
	case account.IsOpenAIOAuth():
		if !settings.OAuthEnabled {
			return openAIWSHTTPDecision("oauth_disabled")
		}
	case account.IsOpenAIApiKey():
		if !settings.APIKeyEnabled {
			return openAIWSHTTPDecision("apikey_disabled")
		}
	default:
		return openAIWSHTTPDecision("unknown_auth_type")
	}

	// Legacy path: a single account-level boolean enables WS.
	if !settings.ModeRouterV2Enabled {
		if !account.IsOpenAIResponsesWebSocketV2Enabled() {
			return openAIWSHTTPDecision("account_disabled")
		}
		switch {
		case settings.ResponsesWebsocketsV2:
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
				Reason:    "ws_v2_enabled",
			}
		case settings.ResponsesWebsockets:
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocket,
				Reason:    "ws_v1_enabled",
			}
		}
		return openAIWSHTTPDecision("feature_disabled")
	}

	// Mode-router path: only the shared and dedicated modes keep WS; the off
	// mode and any unrecognized mode are both reported as "account_mode_off".
	ingressMode := account.ResolveOpenAIResponsesWebSocketV2Mode(settings.IngressModeDefault)
	switch ingressMode {
	case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
		// allowed, keep going
	default:
		return openAIWSHTTPDecision("account_mode_off")
	}
	if account.Concurrency <= 0 {
		return openAIWSHTTPDecision("account_concurrency_invalid")
	}
	switch {
	case settings.ResponsesWebsocketsV2:
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
			Reason:    "ws_v2_mode_" + ingressMode,
		}
	case settings.ResponsesWebsockets:
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocket,
			Reason:    "ws_v1_mode_" + ingressMode,
		}
	}
	return openAIWSHTTPDecision("feature_disabled")
}
// openAIWSHTTPDecision builds an HTTP/SSE decision tagged with reason.
func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
	var decision OpenAIWSProtocolDecision
	decision.Transport = OpenAIUpstreamTransportHTTPSSE
	decision.Reason = reason
	return decision
}
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSProtocolResolver_Resolve covers the legacy (non mode-router)
// decision path: each subtest tweaks a copy of the base config or account and
// asserts both the chosen transport and the reason code.
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
// Base config: WS enabled globally and for both auth types, v2 on, v1 off.
baseCfg := &config.Config{}
baseCfg.Gateway.OpenAIWS.Enabled = true
baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
// Base account: OAuth with the OAuth-specific WS v2 switch enabled.
openAIOAuthEnabled := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
},
}
// v2 is preferred.
t.Run("v2优先", func(t *testing.T) {
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_enabled", decision.Reason)
})
// Falls back to v1 when v2 is disabled.
t.Run("v2关闭时回退v1", func(t *testing.T) {
cfg := *baseCfg
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
require.Equal(t, "ws_v1_enabled", decision.Reason)
})
// The passthrough switch does not affect the WS protocol decision.
t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
account := *openAIOAuthEnabled
account.Extra = map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
"openai_passthrough": true,
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_enabled", decision.Reason)
})
// Account-level force-HTTP wins.
t.Run("账号级强制HTTP", func(t *testing.T) {
account := *openAIOAuthEnabled
account.Extra = map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
"openai_ws_force_http": true,
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "account_force_http", decision.Reason)
})
// Globally disabled stays on HTTP.
t.Run("全局关闭保持HTTP", func(t *testing.T) {
cfg := *baseCfg
cfg.Gateway.OpenAIWS.Enabled = false
decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "global_disabled", decision.Reason)
})
// Account switch off stays on HTTP.
t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
account := *openAIOAuthEnabled
account.Extra = map[string]any{
"openai_oauth_responses_websockets_v2_enabled": false,
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "account_disabled", decision.Reason)
})
// OAuth accounts do not read the API-key-only switch.
t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
account := *openAIOAuthEnabled
account.Extra = map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "account_disabled", decision.Reason)
})
// The legacy key openai_ws_enabled is still honored.
t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
account := *openAIOAuthEnabled
account.Extra = map[string]any{
"openai_ws_enabled": true,
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_enabled", decision.Reason)
})
// The per-auth-type global switch is enforced (OAuth disabled).
t.Run("按账号类型开关控制", func(t *testing.T) {
cfg := *baseCfg
cfg.Gateway.OpenAIWS.OAuthEnabled = false
decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "oauth_disabled", decision.Reason)
})
// API-key accounts fall back to HTTP when the API-key switch is off.
t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
cfg := *baseCfg
cfg.Gateway.OpenAIWS.APIKeyEnabled = false
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "apikey_disabled", decision.Reason)
})
// Unknown auth types fall back to HTTP.
t.Run("未知认证类型回退HTTP", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: "unknown_type",
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "unknown_auth_type", decision.Reason)
})
}
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 covers the mode-router
// decision path (ModeRouterV2Enabled = true): dedicated and legacy-boolean
// accounts route to WS v2, the off mode and non-positive concurrency route to
// HTTP.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
// Config: WS v2 on with the mode router enabled and "shared" as the
// default ingress mode.
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
// OAuth account explicitly set to the dedicated mode.
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
offAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "account_mode_off", decision.Reason)
})
// An account that only has the legacy boolean switch is expected to
// resolve to the shared mode.
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
})
// Concurrency is left at its zero value here, which the resolver rejects.
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
invalidConcurrency := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
require.Equal(t, "account_concurrency_invalid", decision.Reason)
})
}
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
// openAIWSResponseAccountCachePrefix prefixes the hashed response ID to
// form the GatewayCache key.
openAIWSResponseAccountCachePrefix = "openai:response:"
// openAIWSStateStoreCleanupInterval is the minimum time between two
// cleanup passes triggered by maybeCleanup.
openAIWSStateStoreCleanupInterval = time.Minute
// openAIWSStateStoreCleanupMaxPerMap caps how many entries one cleanup
// pass scans in each map.
openAIWSStateStoreCleanupMaxPerMap = 512
// openAIWSStateStoreMaxEntriesPerMap is the hard size limit of each local
// map; inserting a new key at the limit evicts an arbitrary entry.
openAIWSStateStoreMaxEntriesPerMap = 65536
// openAIWSStateStoreRedisTimeout bounds every GatewayCache call made by
// the state store.
openAIWSStateStoreRedisTimeout = 3 * time.Second
)
// openAIWSAccountBinding is a local cache entry binding a response to an
// account until expiresAt.
type openAIWSAccountBinding struct {
accountID int64
expiresAt time.Time
}
// openAIWSConnBinding is a local entry binding a response to a connection ID
// until expiresAt.
type openAIWSConnBinding struct {
connID string
expiresAt time.Time
}
// openAIWSTurnStateBinding is a local entry binding a session to a turn-state
// string until expiresAt.
type openAIWSTurnStateBinding struct {
turnState string
expiresAt time.Time
}
// openAIWSSessionConnBinding is a local entry binding a session to a
// connection ID until expiresAt.
type openAIWSSessionConnBinding struct {
connID string
expiresAt time.Time
}
// OpenAIWSStateStore manages the sticky state used by WSv2.
// - response_id -> account_id is used to route follow-up requests of a chain
// - response_id -> conn_id is used to reuse context within a connection
//
// response_id -> account_id is written to GatewayCache (Redis) and mirrored in
// a local hot cache.
// response_id -> conn_id is only valid inside the current process.
type OpenAIWSStateStore interface {
// Response -> account bindings (local map plus GatewayCache).
BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error
// Response -> connection bindings (local only).
BindResponseConn(responseID, connID string, ttl time.Duration)
GetResponseConn(responseID string) (string, bool)
DeleteResponseConn(responseID string)
// Session -> turn-state bindings, keyed by group and session hash (local only).
BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
DeleteSessionTurnState(groupID int64, sessionHash string)
// Session -> connection bindings, keyed by group and session hash (local only).
BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
GetSessionConn(groupID int64, sessionHash string) (string, bool)
DeleteSessionConn(groupID int64, sessionHash string)
}
// defaultOpenAIWSStateStore is the default OpenAIWSStateStore: four in-memory
// maps, each guarded by its own RWMutex, plus an optional GatewayCache for the
// response -> account binding.
type defaultOpenAIWSStateStore struct {
// cache may be nil, in which case only the local maps are used.
cache GatewayCache
responseToAccountMu sync.RWMutex
responseToAccount map[string]openAIWSAccountBinding
responseToConnMu sync.RWMutex
responseToConn map[string]openAIWSConnBinding
sessionToTurnStateMu sync.RWMutex
sessionToTurnState map[string]openAIWSTurnStateBinding
sessionToConnMu sync.RWMutex
sessionToConn map[string]openAIWSSessionConnBinding
// lastCleanupUnixNano is the UnixNano timestamp of the last cleanup pass;
// maybeCleanup claims a pass by compare-and-swap on it.
lastCleanupUnixNano atomic.Int64
}
// NewOpenAIWSStateStore builds the default WS state store. cache may be nil,
// in which case response -> account bindings live only in the local map.
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
	const initialCapacity = 256

	s := new(defaultOpenAIWSStateStore)
	s.cache = cache
	s.responseToAccount = make(map[string]openAIWSAccountBinding, initialCapacity)
	s.responseToConn = make(map[string]openAIWSConnBinding, initialCapacity)
	s.sessionToTurnState = make(map[string]openAIWSTurnStateBinding, initialCapacity)
	s.sessionToConn = make(map[string]openAIWSSessionConnBinding, initialCapacity)
	// Start the cleanup clock now so the first pass happens one interval later.
	s.lastCleanupUnixNano.Store(time.Now().UnixNano())
	return s
}
// BindResponseAccount records responseID -> accountID in the local map and,
// when a cache is configured, in GatewayCache as well.
//
// A blank response ID or a non-positive account ID is a silent no-op. A
// non-positive ttl is replaced by the default TTL. The returned error comes
// from the cache write only; the local binding is stored before that write.
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
	respID := normalizeOpenAIWSResponseID(responseID)
	if accountID <= 0 || respID == "" {
		return nil
	}
	lifetime := normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	entry := openAIWSAccountBinding{
		accountID: accountID,
		expiresAt: time.Now().Add(lifetime),
	}
	s.responseToAccountMu.Lock()
	ensureBindingCapacity(s.responseToAccount, respID, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToAccount[respID] = entry
	s.responseToAccountMu.Unlock()

	if s.cache == nil {
		return nil
	}
	key := openAIWSResponseAccountCacheKey(respID)
	redisCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.SetSessionAccountID(redisCtx, groupID, key, accountID, lifetime)
}
// GetResponseAccount returns the account bound to responseID, or 0 when there
// is none.
//
// An unexpired local binding wins. Otherwise GatewayCache is consulted (when
// configured); a cache hit is returned but not copied into the local map. The
// error result is always nil: cache failures are treated as a miss.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	respID := normalizeOpenAIWSResponseID(responseID)
	if respID == "" {
		return 0, nil
	}
	s.maybeCleanup()

	now := time.Now()
	s.responseToAccountMu.RLock()
	entry, found := s.responseToAccount[respID]
	s.responseToAccountMu.RUnlock()
	if found && now.Before(entry.expiresAt) {
		return entry.accountID, nil
	}

	if s.cache == nil {
		return 0, nil
	}
	key := openAIWSResponseAccountCacheKey(respID)
	redisCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	cached, err := s.cache.GetSessionAccountID(redisCtx, groupID, key)
	if err == nil && cached > 0 {
		return cached, nil
	}
	// A cache read failure must not block the main flow; degrade to a miss.
	return 0, nil
}
// DeleteResponseAccount removes the responseID binding from the local map and,
// when a cache is configured, from GatewayCache. A blank response ID is a
// no-op. The returned error comes from the cache delete only.
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
	respID := normalizeOpenAIWSResponseID(responseID)
	if len(respID) == 0 {
		return nil
	}

	s.responseToAccountMu.Lock()
	delete(s.responseToAccount, respID)
	s.responseToAccountMu.Unlock()

	if s.cache == nil {
		return nil
	}
	redisCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	key := openAIWSResponseAccountCacheKey(respID)
	return s.cache.DeleteSessionAccountID(redisCtx, groupID, key)
}
// BindResponseConn records responseID -> connID in the local map. Blank IDs
// (after trimming) are ignored; a non-positive ttl is replaced by the default.
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
	respID := normalizeOpenAIWSResponseID(responseID)
	trimmedConn := strings.TrimSpace(connID)
	if respID == "" || trimmedConn == "" {
		return
	}
	lifetime := normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.responseToConnMu.Lock()
	defer s.responseToConnMu.Unlock()
	ensureBindingCapacity(s.responseToConn, respID, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToConn[respID] = openAIWSConnBinding{connID: trimmedConn, expiresAt: time.Now().Add(lifetime)}
}
// GetResponseConn returns the connection ID bound to responseID. The second
// result is false when the binding is missing, expired, or holds a blank ID.
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
	respID := normalizeOpenAIWSResponseID(responseID)
	if respID == "" {
		return "", false
	}
	s.maybeCleanup()

	now := time.Now()
	s.responseToConnMu.RLock()
	entry, found := s.responseToConn[respID]
	s.responseToConnMu.RUnlock()

	switch {
	case !found:
		return "", false
	case now.After(entry.expiresAt):
		return "", false
	case strings.TrimSpace(entry.connID) == "":
		return "", false
	}
	return entry.connID, true
}
// DeleteResponseConn drops the local responseID -> connID binding, if any.
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
	respID := normalizeOpenAIWSResponseID(responseID)
	if len(respID) == 0 {
		return
	}
	s.responseToConnMu.Lock()
	defer s.responseToConnMu.Unlock()
	delete(s.responseToConn, respID)
}
// BindSessionTurnState records (groupID, sessionHash) -> turnState in the local
// map. A blank session hash or turn state (after trimming) is ignored; a
// non-positive ttl is replaced by the default.
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	trimmedState := strings.TrimSpace(turnState)
	if sessionKey == "" || trimmedState == "" {
		return
	}
	lifetime := normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToTurnStateMu.Lock()
	defer s.sessionToTurnStateMu.Unlock()
	ensureBindingCapacity(s.sessionToTurnState, sessionKey, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToTurnState[sessionKey] = openAIWSTurnStateBinding{turnState: trimmedState, expiresAt: time.Now().Add(lifetime)}
}
// GetSessionTurnState returns the turn state bound to (groupID, sessionHash).
// The second result is false when the binding is missing, expired, or blank.
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if sessionKey == "" {
		return "", false
	}
	s.maybeCleanup()

	now := time.Now()
	s.sessionToTurnStateMu.RLock()
	entry, found := s.sessionToTurnState[sessionKey]
	s.sessionToTurnStateMu.RUnlock()

	switch {
	case !found:
		return "", false
	case now.After(entry.expiresAt):
		return "", false
	case strings.TrimSpace(entry.turnState) == "":
		return "", false
	}
	return entry.turnState, true
}
// DeleteSessionTurnState drops the local turn-state binding for
// (groupID, sessionHash), if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if len(sessionKey) == 0 {
		return
	}
	s.sessionToTurnStateMu.Lock()
	defer s.sessionToTurnStateMu.Unlock()
	delete(s.sessionToTurnState, sessionKey)
}
// BindSessionConn records (groupID, sessionHash) -> connID in the local map.
// It uses the same key format as the turn-state map but a separate map. A
// blank session hash or connection ID is ignored; a non-positive ttl is
// replaced by the default.
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	trimmedConn := strings.TrimSpace(connID)
	if sessionKey == "" || trimmedConn == "" {
		return
	}
	lifetime := normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToConnMu.Lock()
	defer s.sessionToConnMu.Unlock()
	ensureBindingCapacity(s.sessionToConn, sessionKey, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToConn[sessionKey] = openAIWSSessionConnBinding{connID: trimmedConn, expiresAt: time.Now().Add(lifetime)}
}
// GetSessionConn returns the connection ID bound to (groupID, sessionHash).
// The second result is false when the binding is missing, expired, or blank.
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if sessionKey == "" {
		return "", false
	}
	s.maybeCleanup()

	now := time.Now()
	s.sessionToConnMu.RLock()
	entry, found := s.sessionToConn[sessionKey]
	s.sessionToConnMu.RUnlock()

	switch {
	case !found:
		return "", false
	case now.After(entry.expiresAt):
		return "", false
	case strings.TrimSpace(entry.connID) == "":
		return "", false
	}
	return entry.connID, true
}
// DeleteSessionConn drops the local connection binding for
// (groupID, sessionHash), if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
	sessionKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if len(sessionKey) == 0 {
		return
	}
	s.sessionToConnMu.Lock()
	defer s.sessionToConnMu.Unlock()
	delete(s.sessionToConn, sessionKey)
}
// maybeCleanup runs one bounded cleanup pass over all four maps if at least
// openAIWSStateStoreCleanupInterval has elapsed since the previous pass.
// The compare-and-swap on lastCleanupUnixNano lets only one caller claim a
// given pass; the others return immediately.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	lastNano := s.lastCleanupUnixNano.Load()
	if now.Sub(time.Unix(0, lastNano)) < openAIWSStateStoreCleanupInterval {
		return
	}
	if claimed := s.lastCleanupUnixNano.CompareAndSwap(lastNano, now.UnixNano()); !claimed {
		return
	}

	// Incremental, capped cleanup: avoids a long stall from a full scan when
	// the maps are large.
	const scanLimit = openAIWSStateStoreCleanupMaxPerMap

	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, scanLimit)
	s.responseToAccountMu.Unlock()

	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, scanLimit)
	s.responseToConnMu.Unlock()

	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, scanLimit)
	s.sessionToTurnStateMu.Unlock()

	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, scanLimit)
	s.sessionToConnMu.Unlock()
}
// cleanupExpiredAccountBindings visits at most maxScan entries of bindings (in
// map iteration order) and deletes those whose expiresAt is before now.
// The caller must hold the map's write lock.
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
	if maxScan <= 0 {
		return
	}
	budget := maxScan
	for id, entry := range bindings {
		if entry.expiresAt.Before(now) {
			delete(bindings, id)
		}
		// Every visited entry counts against the budget, expired or not.
		if budget--; budget == 0 {
			return
		}
	}
}
// cleanupExpiredConnBindings visits at most maxScan entries of bindings (in
// map iteration order) and deletes those whose expiresAt is before now.
// The caller must hold the map's write lock.
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
	if maxScan <= 0 {
		return
	}
	budget := maxScan
	for id, entry := range bindings {
		if entry.expiresAt.Before(now) {
			delete(bindings, id)
		}
		// Every visited entry counts against the budget, expired or not.
		if budget--; budget == 0 {
			return
		}
	}
}
// cleanupExpiredTurnStateBindings visits at most maxScan entries of bindings
// (in map iteration order) and deletes those whose expiresAt is before now.
// The caller must hold the map's write lock.
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
	if maxScan <= 0 {
		return
	}
	budget := maxScan
	for sessionKey, entry := range bindings {
		if entry.expiresAt.Before(now) {
			delete(bindings, sessionKey)
		}
		// Every visited entry counts against the budget, expired or not.
		if budget--; budget == 0 {
			return
		}
	}
}
// cleanupExpiredSessionConnBindings visits at most maxScan entries of bindings
// (in map iteration order) and deletes those whose expiresAt is before now.
// The caller must hold the map's write lock.
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
	if maxScan <= 0 {
		return
	}
	budget := maxScan
	for sessionKey, entry := range bindings {
		if entry.expiresAt.Before(now) {
			delete(bindings, sessionKey)
		}
		// Every visited entry counts against the budget, expired or not.
		if budget--; budget == 0 {
			return
		}
	}
}
// ensureBindingCapacity makes room for incomingKey in bindings.
//
// When the map already holds maxEntries or more entries and incomingKey is not
// one of them, a single arbitrary entry is removed. Nothing happens when
// maxEntries is non-positive, the map is below the limit, or incomingKey is
// already present (an overwrite does not grow the map).
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, present := bindings[incomingKey]; present {
		return
	}
	// Hard size cap: evict whichever entry iteration yields first, trading
	// precision for bounded memory.
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
// normalizeOpenAIWSResponseID returns responseID with leading and trailing
// whitespace removed.
func normalizeOpenAIWSResponseID(responseID string) string {
	trimmed := strings.TrimSpace(responseID)
	return trimmed
}
// openAIWSResponseAccountCacheKey derives the GatewayCache key for responseID:
// the fixed prefix followed by the hex-encoded SHA-256 of the ID.
func openAIWSResponseAccountCacheKey(responseID string) string {
	hasher := sha256.New()
	// hash.Hash documents that Write never returns an error.
	_, _ = hasher.Write([]byte(responseID))
	return openAIWSResponseAccountCachePrefix + hex.EncodeToString(hasher.Sum(nil))
}
// normalizeOpenAIWSTTL returns ttl when it is positive and one hour otherwise.
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
	if ttl > 0 {
		return ttl
	}
	return time.Hour
}
// openAIWSSessionTurnStateKey builds the local map key "<groupID>:<hash>" from
// the trimmed session hash. It returns "" when the hash is blank, which
// callers treat as "no key".
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	trimmed := strings.TrimSpace(sessionHash)
	if len(trimmed) == 0 {
		return ""
	}
	return fmt.Sprintf("%d:%s", groupID, trimmed)
}
// withOpenAIWSStateStoreRedisTimeout derives a context bounded by
// openAIWSStateStoreRedisTimeout. A nil ctx is replaced by
// context.Background(). The caller must invoke the returned cancel function.
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, openAIWSStateStoreRedisTimeout)
}
package service
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSStateStore_BindGetDeleteResponseAccount exercises the full
// bind -> get -> delete -> get cycle of the response -> account binding.
func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
	const responseID = "resp_abc"
	var groupID int64 = 7
	ctx := context.Background()
	store := NewOpenAIWSStateStore(&stubGatewayCache{})

	err := store.BindResponseAccount(ctx, groupID, responseID, 101, time.Minute)
	require.NoError(t, err)

	got, err := store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Equal(t, int64(101), got)

	err = store.DeleteResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)

	// After deletion the lookup must report no account.
	got, err = store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Zero(t, got)
}
// TestOpenAIWSStateStore_ResponseConnTTL checks that a response -> conn
// binding is readable right after binding and gone once its TTL has elapsed.
func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
	const (
		responseID = "resp_conn"
		ttl        = 30 * time.Millisecond
	)
	store := NewOpenAIWSStateStore(nil)
	store.BindResponseConn(responseID, "conn_1", ttl)

	got, found := store.GetResponseConn(responseID)
	require.True(t, found)
	require.Equal(t, "conn_1", got)

	// Wait twice the TTL so the binding is safely expired.
	time.Sleep(2 * ttl)
	_, found = store.GetResponseConn(responseID)
	require.False(t, found)
}
// TestOpenAIWSStateStore_SessionTurnStateTTL checks that a session turn-state
// binding is visible only within its own group and disappears after its TTL.
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
	const (
		sessionHash = "session_hash_1"
		ttl         = 30 * time.Millisecond
	)
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionTurnState(9, sessionHash, "turn_state_1", ttl)

	got, found := store.GetSessionTurnState(9, sessionHash)
	require.True(t, found)
	require.Equal(t, "turn_state_1", got)

	// Group isolation: the same hash under another group must miss.
	_, found = store.GetSessionTurnState(10, sessionHash)
	require.False(t, found)

	// Wait twice the TTL so the binding is safely expired.
	time.Sleep(2 * ttl)
	_, found = store.GetSessionTurnState(9, sessionHash)
	require.False(t, found)
}
// TestOpenAIWSStateStore_SessionConnTTL checks that a session -> conn binding
// is visible only within its own group and disappears after its TTL.
func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
	const (
		sessionHash = "session_hash_conn_1"
		ttl         = 30 * time.Millisecond
	)
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionConn(9, sessionHash, "conn_1", ttl)

	got, found := store.GetSessionConn(9, sessionHash)
	require.True(t, found)
	require.Equal(t, "conn_1", got)

	// Group isolation: the same hash under another group must miss.
	_, found = store.GetSessionConn(10, sessionHash)
	require.False(t, found)

	// Wait twice the TTL so the binding is safely expired.
	time.Sleep(2 * ttl)
	_, found = store.GetSessionConn(9, sessionHash)
	require.False(t, found)
}
// TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss seeds the stub
// cache directly, reads the binding through the store, removes the entry from
// the stub cache, and asserts that a second read returns zero rather than the
// value observed on the first read.
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
store := NewOpenAIWSStateStore(cache)
ctx := context.Background()
groupID := int64(17)
responseID := "resp_cache_stale"
cacheKey := openAIWSResponseAccountCacheKey(responseID)
// Write straight into the stub cache, bypassing BindResponseAccount.
cache.sessionBindings[cacheKey] = 501
accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
require.NoError(t, err)
require.Equal(t, int64(501), accountID)
// Simulate the upstream cache losing the entry.
delete(cache.sessionBindings, cacheKey)
accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
require.NoError(t, err)
require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}
// TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally fills
// responseToConn with 2048 already-expired bindings and asserts that one
// maybeCleanup pass removes some but not all of them, and that up to eight
// further passes empty the map.
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
raw := NewOpenAIWSStateStore(nil)
// The test needs the concrete type to reach unexported fields and maybeCleanup.
store, ok := raw.(*defaultOpenAIWSStateStore)
require.True(t, ok)
expiredAt := time.Now().Add(-time.Minute)
total := 2048
store.responseToConnMu.Lock()
for i := 0; i < total; i++ {
store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
connID: "conn_incremental",
expiresAt: expiredAt,
}
}
store.responseToConnMu.Unlock()
// Rewind the last-cleanup timestamp by two intervals before each call;
// presumably this keeps maybeCleanup's interval gate from skipping the pass.
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
store.maybeCleanup()
store.responseToConnMu.RLock()
remainingAfterFirst := len(store.responseToConn)
store.responseToConnMu.RUnlock()
require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
for i := 0; i < 8; i++ {
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
store.maybeCleanup()
}
store.responseToConnMu.RLock()
remaining := len(store.responseToConn)
store.responseToConnMu.RUnlock()
require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
// TestEnsureBindingCapacity_EvictsOneWhenMapIsFull calls ensureBindingCapacity
// with a new key on a map already holding 2 entries and a limit of 2, then
// inserts the key and expects the map to still hold exactly 2 entries.
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
bindings := map[string]int{
"a": 1,
"b": 2,
}
ensureBindingCapacity(bindings, "c", 2)
bindings["c"] = 3
require.Len(t, bindings, 2)
require.Equal(t, 3, bindings["c"])
}
// TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey calls
// ensureBindingCapacity with a key that is already present in a full map, then
// overwrites that key and expects the map size to remain 2.
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
bindings := map[string]int{
"a": 1,
"b": 2,
}
ensureBindingCapacity(bindings, "a", 2)
bindings["a"] = 9
require.Len(t, bindings, 2)
require.Equal(t, 9, bindings["a"])
}
// openAIWSStateStoreTimeoutProbeCache is a test double whose Get/Set/Delete
// methods record whether the context they receive carries a deadline and how
// far in the future that deadline is.
type openAIWSStateStoreTimeoutProbeCache struct {
// The *HasDeadline flags are set to true when the matching method sees a ctx deadline.
setHasDeadline bool
getHasDeadline bool
deleteHasDeadline bool
// The *DeadlineDelta fields hold time.Until(deadline) measured inside the matching method.
setDeadlineDelta time.Duration
getDeadlineDelta time.Duration
delDeadlineDelta time.Duration
}
// GetSessionAccountID records the deadline carried by ctx, if any, and always
// answers with account ID 123 and a nil error.
func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		c.getHasDeadline = true
		c.getDeadlineDelta = time.Until(deadline)
	}
	return 123, nil
}
// SetSessionAccountID records the deadline carried by ctx, if any, and always
// fails with the error "set failed".
func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		c.setHasDeadline = true
		c.setDeadlineDelta = time.Until(deadline)
	}
	return errors.New("set failed")
}
// RefreshSessionTTL is a no-op that ignores its arguments and returns nil.
func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
return nil
}
// DeleteSessionAccountID records the deadline carried by ctx, if any, and
// always reports success.
func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error {
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		c.deleteHasDeadline = true
		c.delDeadlineDelta = time.Until(deadline)
	}
	return nil
}
// TestOpenAIWSStateStore_RedisOpsUseShortTimeout drives the store through a
// probe cache and asserts that the contexts passed to the cache's Set, Delete
// and Get methods each carry a deadline more than 2s and at most 3s away, even
// though the caller's context has none.
func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
probe := &openAIWSStateStoreTimeoutProbeCache{}
store := NewOpenAIWSStateStore(probe)
ctx := context.Background()
groupID := int64(5)
// The probe's SetSessionAccountID always fails, so Bind is expected to surface an error.
err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
require.Error(t, err)
// Despite the failed cache write, the bound account must still be readable
// without the probe's Get being invoked (asserted via getHasDeadline below).
accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
require.NoError(t, getErr)
require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)
// A fresh store with nothing bound forces the read through the probe's Get.
probe2 := &openAIWSStateStoreTimeoutProbeCache{}
store2 := NewOpenAIWSStateStore(probe2)
accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
require.NoError(t, err2)
require.Equal(t, int64(123), accountID2)
require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
// TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext checks that wrapping
// a deadline-free background context yields a non-nil context with a deadline.
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
defer cancel()
require.NotNil(t, ctx)
_, ok := ctx.Deadline()
require.True(t, ok, "应附加短超时")
}
......@@ -13,7 +13,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
......@@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
attemptCtx := ctx
if switches > 0 {
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false)
}
exec := func() *opsRetryExecution {
defer selection.ReleaseFunc()
......@@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.
}
c.Request = req
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
return c, w
}
......
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders builds a retry
// context from an error log and asserts the request path, content type, user
// agent, the two replayed anthropic-* headers, the absence of the authorization
// header, and that the OpenAI client transport is marked as HTTP.
func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) {
errorLog := &OpsErrorLogDetail{
OpsErrorLog: OpsErrorLog{
RequestPath: "/openai/v1/responses",
},
UserAgent: "ops-retry-agent/1.0",
// Header names use mixed casing; the assertions below read them case-insensitively.
RequestHeaders: `{
"anthropic-beta":"beta-v1",
"ANTHROPIC-VERSION":"2023-06-01",
"authorization":"Bearer should-not-forward"
}`,
}
c, w := newOpsRetryContext(context.Background(), errorLog)
require.NotNil(t, c)
require.NotNil(t, w)
require.NotNil(t, c.Request)
require.Equal(t, "/openai/v1/responses", c.Request.URL.Path)
require.Equal(t, "application/json", c.Request.Header.Get("Content-Type"))
require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent"))
require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta"))
require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version"))
require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放")
require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}
// TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport passes
// malformed header JSON and no request path, and asserts that a request with
// path "/" is still produced and the transport is still marked as HTTP.
func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) {
errorLog := &OpsErrorLogDetail{
RequestHeaders: "{invalid-json",
}
c, _ := newOpsRetryContext(context.Background(), errorLog)
require.NotNil(t, c)
require.NotNil(t, c.Request)
require.Equal(t, "/", c.Request.URL.Path)
require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}
......@@ -27,6 +27,11 @@ const (
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
OpsResponseLatencyMsKey = "ops_response_latency_ms"
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
// OpenAI WS 关键观测字段
OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms"
OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms"
OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused"
OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
......
......@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// RateLimitService 处理限流和过载状态管理
......@@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct {
totals GeminiUsageTotals
}
// geminiUsageTotalsBatchProvider is an optional capability of the usage
// repository: an implementation returns Gemini usage totals for several
// accounts over [startTime, endTime] in one call, keyed by account ID.
// getGeminiUsageTotalsBatch detects it with a type assertion on s.usageRepo.
type geminiUsageTotalsBatchProvider interface {
GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error)
}
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
......@@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
logger.LegacyPrintf(
"service.ratelimit",
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
account.ID,
account.Platform,
account.Type,
strings.TrimSpace(headers.Get("x-request-id")),
strings.TrimSpace(headers.Get("cf-ray")),
upstreamMsg,
truncateForLog(responseBody, 1024),
)
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 429:
......@@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
......@@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
......@@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil
}
// PreCheckUsageBatch runs the Gemini quota precheck for several accounts at once.
//
// The returned map holds one entry per non-nil account, keyed by account ID;
// false means the account should be skipped. Every account starts as true and
// is only flipped to false when it is a Gemini account with a known quota whose
// daily or per-minute usage has reached the applicable limit. When
// requestedModel is empty, or the usage repository or quota service is not
// configured, all accounts stay true. If loading usage fails, the map built so
// far is returned together with the error.
func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) {
	result := make(map[int64]bool, len(accounts))
	for _, acc := range accounts {
		if acc != nil {
			result[acc.ID] = true
		}
	}
	switch {
	case len(accounts) == 0, requestedModel == "":
		return result, nil
	case s.usageRepo == nil, s.geminiQuotaService == nil:
		return result, nil
	}

	modelClass := geminiModelClassFromName(requestedModel)
	now := time.Now()
	dailyStart := geminiDailyWindowStart(now)
	minuteStart := now.Truncate(time.Minute)

	// Only Gemini accounts for which the quota service reports a quota are checked.
	type quotaAccount struct {
		account *Account
		quota   GeminiQuota
	}
	candidates := make([]quotaAccount, 0, len(accounts))
	for _, acc := range accounts {
		if acc == nil || acc.Platform != PlatformGemini {
			continue
		}
		if quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, acc); ok {
			candidates = append(candidates, quotaAccount{account: acc, quota: quota})
		}
	}
	if len(candidates) == 0 {
		return result, nil
	}

	// Step 1: daily window. Serve from the in-memory cache where possible and
	// load the remaining accounts in one batch, caching what was loaded.
	dailyTotals := make(map[int64]GeminiUsageTotals, len(candidates))
	uncachedIDs := make([]int64, 0, len(candidates))
	for _, cand := range candidates {
		if geminiDailyLimit(cand.quota, modelClass) <= 0 {
			continue
		}
		id := cand.account.ID
		totals, cached := s.getGeminiUsageTotals(id, dailyStart, now)
		if !cached {
			uncachedIDs = append(uncachedIDs, id)
			continue
		}
		dailyTotals[id] = totals
	}
	if len(uncachedIDs) > 0 {
		fetched, err := s.getGeminiUsageTotalsBatch(ctx, uncachedIDs, dailyStart, now)
		if err != nil {
			return result, err
		}
		for _, id := range uncachedIDs {
			totals := fetched[id]
			dailyTotals[id] = totals
			s.setGeminiUsageTotals(id, dailyStart, now, totals)
		}
	}
	for _, cand := range candidates {
		limit := geminiDailyLimit(cand.quota, modelClass)
		if limit <= 0 {
			continue
		}
		id := cand.account.ID
		used := geminiUsedRequests(cand.quota, modelClass, dailyTotals[id], true)
		if used < limit {
			continue
		}
		resetAt := geminiDailyResetTime(now)
		slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", id, "used", used, "limit", limit, "reset_at", resetAt)
		result[id] = false
	}

	// Step 2: per-minute window, loaded in one batch for the accounts that
	// passed the daily check and have a positive minute limit.
	minuteIDs := make([]int64, 0, len(candidates))
	for _, cand := range candidates {
		id := cand.account.ID
		if result[id] && geminiMinuteLimit(cand.quota, modelClass) > 0 {
			minuteIDs = append(minuteIDs, id)
		}
	}
	if len(minuteIDs) == 0 {
		return result, nil
	}
	minuteTotals, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now)
	if err != nil {
		return result, err
	}
	for _, cand := range candidates {
		id := cand.account.ID
		if !result[id] {
			continue
		}
		limit := geminiMinuteLimit(cand.quota, modelClass)
		if limit <= 0 {
			continue
		}
		used := geminiUsedRequests(cand.quota, modelClass, minuteTotals[id], false)
		if used < limit {
			continue
		}
		resetAt := minuteStart.Add(time.Minute)
		slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", id, "used", used, "limit", limit, "reset_at", resetAt)
		result[id] = false
	}
	return result, nil
}
// getGeminiUsageTotalsBatch returns Gemini usage totals for accountIDs over
// [start, end], keyed by account ID. Non-positive and duplicate IDs are
// dropped first. If the usage repository implements
// geminiUsageTotalsBatchProvider a single batch call is made; otherwise each
// account is queried individually and aggregated with geminiAggregateUsage.
// On any repository error the result is nil and the error is returned.
func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) {
	totalsByID := make(map[int64]GeminiUsageTotals, len(accountIDs))
	if len(accountIDs) == 0 {
		return totalsByID, nil
	}

	uniqueIDs := make([]int64, 0, len(accountIDs))
	visited := make(map[int64]struct{}, len(accountIDs))
	for _, id := range accountIDs {
		if id <= 0 {
			continue
		}
		if _, dup := visited[id]; dup {
			continue
		}
		visited[id] = struct{}{}
		uniqueIDs = append(uniqueIDs, id)
	}
	if len(uniqueIDs) == 0 {
		return totalsByID, nil
	}

	batchReader, supportsBatch := s.usageRepo.(geminiUsageTotalsBatchProvider)
	if !supportsBatch {
		// Fallback: one stats query per account.
		for _, id := range uniqueIDs {
			stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, id, 0, nil, nil, nil)
			if err != nil {
				return nil, err
			}
			totalsByID[id] = geminiAggregateUsage(stats)
		}
		return totalsByID, nil
	}

	stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, uniqueIDs, start, end)
	if err != nil {
		return nil, err
	}
	// Accounts missing from the batch answer get the zero totals value.
	for _, id := range uniqueIDs {
		totalsByID[id] = stats[id]
	}
	return totalsByID, nil
}
// geminiDailyLimit picks the requests-per-day limit that applies to modelClass:
// quota.SharedRPD when it is positive, otherwise FlashRPD for the flash class
// and ProRPD for every other class.
func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
	if shared := quota.SharedRPD; shared > 0 {
		return shared
	}
	if modelClass == geminiModelFlash {
		return quota.FlashRPD
	}
	return quota.ProRPD
}
// geminiMinuteLimit picks the requests-per-minute limit that applies to
// modelClass: quota.SharedRPM when it is positive, otherwise FlashRPM for the
// flash class and ProRPM for every other class.
func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
	if shared := quota.SharedRPM; shared > 0 {
		return shared
	}
	if modelClass == geminiModelFlash {
		return quota.FlashRPM
	}
	return quota.ProRPM
}
// geminiUsedRequests returns the request count to compare against the limit
// for modelClass. When the shared limit of the selected window (SharedRPD if
// daily, SharedRPM otherwise) is positive, pro and flash requests are summed;
// otherwise only the requests of the matching class are counted, with every
// non-flash class counted as pro.
func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 {
	sharedLimit := quota.SharedRPM
	if daily {
		sharedLimit = quota.SharedRPD
	}
	if sharedLimit > 0 {
		return totals.ProRequests + totals.FlashRequests
	}
	if modelClass == geminiModelFlash {
		return totals.FlashRequests
	}
	return totals.ProRequests
}
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.RLock()
defer s.usageCacheMu.RUnlock()
......
package service
import (
"context"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
// requestMetadataContextKey is the private key type under which the
// *RequestMetadata value is stored in a context.
type requestMetadataContextKey struct{}
// requestMetadataKey is the single key value used with context.WithValue / ctx.Value.
var requestMetadataKey = requestMetadataContextKey{}
// RequestMetadata groups per-request flags carried in a context under
// requestMetadataKey. Every field is a pointer; the *FromContext readers treat
// a nil field as "not set" and fall back to the legacy ctxkey value.
type RequestMetadata struct {
IsMaxTokensOneHaikuRequest *bool
ThinkingEnabled *bool
PrefetchedStickyAccountID *int64
PrefetchedStickyGroupID *int64
SingleAccountRetry *bool
AccountSwitchCount *int
}
// Fallback counters: each is incremented by the matching *FromContext reader
// when the value is served from the legacy ctxkey entry instead of
// RequestMetadata. They are exposed through RequestMetadataFallbackStats.
var (
requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
requestMetadataFallbackThinkingEnabledTotal atomic.Int64
requestMetadataFallbackPrefetchedStickyAccount atomic.Int64
requestMetadataFallbackPrefetchedStickyGroup atomic.Int64
requestMetadataFallbackSingleAccountRetryTotal atomic.Int64
requestMetadataFallbackAccountSwitchCountTotal atomic.Int64
)
// RequestMetadataFallbackStats reports the current value of each legacy
// fallback counter, in the order given by the named results. Each counter is
// loaded on its own, so the six values are not one consistent snapshot.
func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
	isMaxTokensOneHaiku = requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load()
	thinkingEnabled = requestMetadataFallbackThinkingEnabledTotal.Load()
	prefetchedStickyAccount = requestMetadataFallbackPrefetchedStickyAccount.Load()
	prefetchedStickyGroup = requestMetadataFallbackPrefetchedStickyGroup.Load()
	singleAccountRetry = requestMetadataFallbackSingleAccountRetryTotal.Load()
	accountSwitchCount = requestMetadataFallbackAccountSwitchCountTotal.Load()
	return isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount
}
// metadataFromContext returns the *RequestMetadata stored in ctx under
// requestMetadataKey, or nil when ctx is nil or holds no such value.
func metadataFromContext(ctx context.Context) *RequestMetadata {
	if ctx != nil {
		if md, ok := ctx.Value(requestMetadataKey).(*RequestMetadata); ok {
			return md
		}
	}
	return nil
}
// updateRequestMetadata returns a context carrying an updated copy of the
// request metadata. The metadata already in ctx is never mutated: it is copied
// (or a zero value is used when absent), update is applied to the copy, and
// the copy is attached to a derived context. When bridgeOldKeys is true and
// legacyBridge is non-nil, legacyBridge is then applied to the derived context
// so the legacy ctxkey entries can be written too. A nil ctx yields nil.
func updateRequestMetadata(
	ctx context.Context,
	bridgeOldKeys bool,
	update func(md *RequestMetadata),
	legacyBridge func(ctx context.Context) context.Context,
) context.Context {
	if ctx == nil {
		return nil
	}
	var snapshot RequestMetadata
	if prev := metadataFromContext(ctx); prev != nil {
		snapshot = *prev
	}
	update(&snapshot)
	derived := context.WithValue(ctx, requestMetadataKey, &snapshot)
	if !bridgeOldKeys || legacyBridge == nil {
		return derived
	}
	return legacyBridge(derived)
}
func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.IsMaxTokensOneHaikuRequest = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
})
}
func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.ThinkingEnabled = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.ThinkingEnabled, value)
})
}
func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
account := accountID
group := groupID
md.PrefetchedStickyAccountID = &account
md.PrefetchedStickyGroupID = &group
}, func(base context.Context) context.Context {
bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
})
}
func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.SingleAccountRetry = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.SingleAccountRetry, value)
})
}
func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.AccountSwitchCount = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.AccountSwitchCount, value)
})
}
// IsMaxTokensOneHaikuRequestFromContext reads the flag from ctx. The second
// result reports whether a value was found. RequestMetadata is consulted
// first; otherwise a bool stored under ctxkey.IsMaxTokensOneHaikuRequest is
// used and the matching fallback counter is incremented.
func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
	if ctx == nil {
		return false, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if flag := md.IsMaxTokensOneHaikuRequest; flag != nil {
			return *flag, true
		}
	}
	legacy, found := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool)
	if !found {
		return false, false
	}
	requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
	return legacy, true
}
// ThinkingEnabledFromContext reads the flag from ctx. The second result
// reports whether a value was found. RequestMetadata is consulted first;
// otherwise a bool stored under ctxkey.ThinkingEnabled is used and the
// matching fallback counter is incremented.
func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
	if ctx == nil {
		return false, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if flag := md.ThinkingEnabled; flag != nil {
			return *flag, true
		}
	}
	legacy, found := ctx.Value(ctxkey.ThinkingEnabled).(bool)
	if !found {
		return false, false
	}
	requestMetadataFallbackThinkingEnabledTotal.Add(1)
	return legacy, true
}
// PrefetchedStickyGroupIDFromContext reads the prefetched sticky group ID from
// ctx. The second result reports whether a value was found. RequestMetadata is
// consulted first; otherwise an int64 or int stored under
// ctxkey.PrefetchedStickyGroupID is used and the fallback counter is
// incremented.
func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
	if ctx == nil {
		return 0, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if id := md.PrefetchedStickyGroupID; id != nil {
			return *id, true
		}
	}
	switch legacy := ctx.Value(ctxkey.PrefetchedStickyGroupID).(type) {
	case int64:
		requestMetadataFallbackPrefetchedStickyGroup.Add(1)
		return legacy, true
	case int:
		requestMetadataFallbackPrefetchedStickyGroup.Add(1)
		return int64(legacy), true
	default:
		return 0, false
	}
}
// PrefetchedStickyAccountIDFromContext reads the prefetched sticky account ID
// from ctx. The second result reports whether a value was found.
// RequestMetadata is consulted first; otherwise an int64 or int stored under
// ctxkey.PrefetchedStickyAccountID is used and the fallback counter is
// incremented.
func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
	if ctx == nil {
		return 0, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if id := md.PrefetchedStickyAccountID; id != nil {
			return *id, true
		}
	}
	switch legacy := ctx.Value(ctxkey.PrefetchedStickyAccountID).(type) {
	case int64:
		requestMetadataFallbackPrefetchedStickyAccount.Add(1)
		return legacy, true
	case int:
		requestMetadataFallbackPrefetchedStickyAccount.Add(1)
		return int64(legacy), true
	default:
		return 0, false
	}
}
// SingleAccountRetryFromContext reads the flag from ctx. The second result
// reports whether a value was found. RequestMetadata is consulted first;
// otherwise a bool stored under ctxkey.SingleAccountRetry is used and the
// matching fallback counter is incremented.
func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
	if ctx == nil {
		return false, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if flag := md.SingleAccountRetry; flag != nil {
			return *flag, true
		}
	}
	legacy, found := ctx.Value(ctxkey.SingleAccountRetry).(bool)
	if !found {
		return false, false
	}
	requestMetadataFallbackSingleAccountRetryTotal.Add(1)
	return legacy, true
}
// AccountSwitchCountFromContext reads the account switch count from ctx. The
// second result reports whether a value was found. RequestMetadata is
// consulted first; otherwise an int or int64 stored under
// ctxkey.AccountSwitchCount is used (int64 is converted to int) and the
// fallback counter is incremented.
func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
	if ctx == nil {
		return 0, false
	}
	if md := metadataFromContext(ctx); md != nil {
		if count := md.AccountSwitchCount; count != nil {
			return *count, true
		}
	}
	switch legacy := ctx.Value(ctxkey.AccountSwitchCount).(type) {
	case int:
		requestMetadataFallbackAccountSwitchCountTotal.Add(1)
		return legacy, true
	case int64:
		requestMetadataFallbackAccountSwitchCountTotal.Add(1)
		return int(legacy), true
	default:
		return 0, false
	}
}
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// TestRequestMetadataWriteAndRead_NoBridge writes every metadata field with
// bridgeOldKeys=false, reads each one back through its *FromContext accessor,
// and asserts that none of the legacy ctxkey entries were populated.
func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) {
ctx := context.Background()
ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false)
ctx = WithThinkingEnabled(ctx, true, false)
ctx = WithPrefetchedStickySession(ctx, 123, 456, false)
ctx = WithSingleAccountRetry(ctx, true, false)
ctx = WithAccountSwitchCount(ctx, 2, false)
isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
require.True(t, ok)
require.True(t, isHaiku)
thinking, ok := ThinkingEnabledFromContext(ctx)
require.True(t, ok)
require.True(t, thinking)
accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
require.True(t, ok)
require.Equal(t, int64(123), accountID)
groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
require.True(t, ok)
require.Equal(t, int64(456), groupID)
singleRetry, ok := SingleAccountRetryFromContext(ctx)
require.True(t, ok)
require.True(t, singleRetry)
switchCount, ok := AccountSwitchCountFromContext(ctx)
require.True(t, ok)
require.Equal(t, 2, switchCount)
// Without bridging, the legacy keys must stay unset.
require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled))
require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID))
require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID))
require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry))
require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount))
}
// TestRequestMetadataWrite_BridgeLegacyKeys writes every metadata field with
// bridgeOldKeys=true and asserts that each legacy ctxkey entry holds the
// written value with its original type.
func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) {
ctx := context.Background()
ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true)
ctx = WithThinkingEnabled(ctx, true, true)
ctx = WithPrefetchedStickySession(ctx, 123, 456, true)
ctx = WithSingleAccountRetry(ctx, true, true)
ctx = WithAccountSwitchCount(ctx, 2, true)
require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled))
require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID))
require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID))
require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry))
require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount))
}
// TestRequestMetadataRead_LegacyFallbackAndStats populates only the legacy
// ctxkey entries, reads each value through its *FromContext accessor, and
// asserts that every fallback counter grew by exactly one.
func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) {
// Counters are package-level, so compare against a baseline rather than zero.
beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats()
ctx := context.Background()
ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true)
ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654))
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
// Stored as int64 so the reader's int64→int conversion branch is exercised.
ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3))
isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
require.True(t, ok)
require.True(t, isHaiku)
thinking, ok := ThinkingEnabledFromContext(ctx)
require.True(t, ok)
require.True(t, thinking)
accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
require.True(t, ok)
require.Equal(t, int64(321), accountID)
groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
require.True(t, ok)
require.Equal(t, int64(654), groupID)
singleRetry, ok := SingleAccountRetryFromContext(ctx)
require.True(t, ok)
require.True(t, singleRetry)
switchCount, ok := AccountSwitchCountFromContext(ctx)
require.True(t, ok)
require.Equal(t, 3, switchCount)
afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats()
require.Equal(t, beforeHaiku+1, afterHaiku)
require.Equal(t, beforeThinking+1, afterThinking)
require.Equal(t, beforeAccount+1, afterAccount)
require.Equal(t, beforeGroup+1, afterGroup)
require.Equal(t, beforeSingleRetry+1, afterSingleRetry)
require.Equal(t, beforeSwitchCount+1, afterSwitchCount)
}
// TestRequestMetadataRead_PreferMetadataOverLegacy stores false under the
// legacy key and true in RequestMetadata (without bridging), then asserts the
// accessor returns the metadata value while the legacy entry stays false.
func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
ctx = WithThinkingEnabled(ctx, true, false)
thinking, ok := ThinkingEnabledFromContext(ctx)
require.True(t, ok)
require.True(t, thinking)
require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled))
}
package service
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
)
// compileResponseHeaderFilter compiles the response-header filter from
// cfg.Security.ResponseHeaders. It returns nil when cfg is nil.
func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter {
	if cfg != nil {
		return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders)
	}
	return nil
}
......@@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
if payload == nil {
return nil
}
ids := parseInt64Slice(payload["account_ids"])
for _, id := range ids {
if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
return err
if s.accountRepo == nil {
return nil
}
rawIDs := parseInt64Slice(payload["account_ids"])
if len(rawIDs) == 0 {
return nil
}
ids := make([]int64, 0, len(rawIDs))
seen := make(map[int64]struct{}, len(rawIDs))
for _, id := range rawIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
return nil
if len(ids) == 0 {
return nil
}
preloadGroupIDs := parseInt64Slice(payload["group_ids"])
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return err
}
found := make(map[int64]struct{}, len(accounts))
rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs))
for _, gid := range preloadGroupIDs {
if gid > 0 {
rebuildGroupSet[gid] = struct{}{}
}
}
for _, account := range accounts {
if account == nil || account.ID <= 0 {
continue
}
found[account.ID] = struct{}{}
if s.cache != nil {
if err := s.cache.SetAccount(ctx, account); err != nil {
return err
}
}
for _, gid := range account.GroupIDs {
if gid > 0 {
rebuildGroupSet[gid] = struct{}{}
}
}
}
if s.cache != nil {
for _, id := range ids {
if _, ok := found[id]; ok {
continue
}
if err := s.cache.DeleteAccount(ctx, id); err != nil {
return err
}
}
}
rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet))
for gid := range rebuildGroupSet {
rebuildGroupIDs = append(rebuildGroupIDs, gid)
}
return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change")
}
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
......
......@@ -9,14 +9,17 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
)
type SettingRepository interface {
......@@ -34,6 +37,7 @@ type SettingService struct {
settingRepo SettingRepository
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
}
......@@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyLinuxDoConnectEnabled,
}
......@@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
......@@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
// SetOnS3UpdateCallback sets the callback to run when the Sora S3 configuration
// changes (used to refresh the S3 client cache). It stores callback in
// s.onS3Update, replacing any previous value.
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
s.onS3Update = callback
}
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
......@@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
......@@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
......@@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
......@@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
......@@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
}
// 解析整数类型
......@@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
// soraS3ProfilesStore is the JSON document persisted under
// SettingKeySoraS3Profiles: every stored Sora S3 profile plus the ID of the
// profile that is currently active.
type soraS3ProfilesStore struct {
	// ActiveProfileID is the ProfileID of the active item; empty when there are no items.
	ActiveProfileID string `json:"active_profile_id"`
	// Items holds the stored profiles in their persisted order.
	Items []soraS3ProfileStoreItem `json:"items"`
}
// soraS3ProfileStoreItem is the persisted form of a single Sora S3 profile.
// Unlike SoraS3Profile, the secret access key is serialized here because this
// is the representation written to the settings repository.
type soraS3ProfileStoreItem struct {
	// ProfileID uniquely identifies the profile within the store.
	ProfileID string `json:"profile_id"`
	// Name is the human-readable label of the profile.
	Name string `json:"name"`
	Enabled bool `json:"enabled"`
	Endpoint string `json:"endpoint"`
	Region string `json:"region"`
	Bucket string `json:"bucket"`
	AccessKeyID string `json:"access_key_id"`
	// SecretAccessKey is stored as-is (never trimmed by the service code).
	SecretAccessKey string `json:"secret_access_key"`
	Prefix string `json:"prefix"`
	ForcePathStyle bool `json:"force_path_style"`
	CDNURL string `json:"cdn_url"`
	// DefaultStorageQuotaBytes is clamped to be non-negative by the writers in this file.
	DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
	// UpdatedAt is written by the service as an RFC 3339 timestamp in UTC.
	UpdatedAt string `json:"updated_at"`
}
// GetSoraS3Settings returns the Sora S3 storage configuration in the legacy
// single-configuration shape by projecting the currently active profile.
// When no profile is stored, an empty SoraS3Settings value is returned.
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
	list, err := s.ListSoraS3Profiles(ctx)
	if err != nil {
		return nil, err
	}

	current := pickActiveSoraS3Profile(list.Items, list.ActiveProfileID)
	if current == nil {
		return &SoraS3Settings{}, nil
	}

	result := &SoraS3Settings{}
	result.Enabled = current.Enabled
	result.Endpoint = current.Endpoint
	result.Region = current.Region
	result.Bucket = current.Bucket
	result.AccessKeyID = current.AccessKeyID
	result.SecretAccessKey = current.SecretAccessKey
	result.SecretAccessKeyConfigured = current.SecretAccessKeyConfigured
	result.Prefix = current.Prefix
	result.ForcePathStyle = current.ForcePathStyle
	result.CDNURL = current.CDNURL
	result.DefaultStorageQuotaBytes = current.DefaultStorageQuotaBytes
	return result, nil
}
// SetSoraS3Settings updates the Sora S3 storage configuration using the legacy
// single-configuration semantics: the values are written into the currently
// active profile. If the store has no active profile, a new one named
// "Default" is appended and made active first. An empty SecretAccessKey leaves
// the stored secret untouched.
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
	if settings == nil {
		return fmt.Errorf("settings cannot be nil")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return err
	}

	stamp := time.Now().UTC().Format(time.RFC3339)
	pos := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
	if pos < 0 {
		// No active profile yet: create one, avoiding a clash with an existing "default" ID.
		newID := "default"
		if hasSoraS3ProfileID(store.Items, newID) {
			newID = fmt.Sprintf("default-%d", time.Now().Unix())
		}
		store.Items = append(store.Items, soraS3ProfileStoreItem{
			ProfileID: newID,
			Name:      "Default",
			UpdatedAt: stamp,
		})
		store.ActiveProfileID = newID
		pos = len(store.Items) - 1
	}

	entry := &store.Items[pos]
	entry.Enabled = settings.Enabled
	entry.Endpoint = strings.TrimSpace(settings.Endpoint)
	entry.Region = strings.TrimSpace(settings.Region)
	entry.Bucket = strings.TrimSpace(settings.Bucket)
	entry.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
	entry.Prefix = strings.TrimSpace(settings.Prefix)
	entry.ForcePathStyle = settings.ForcePathStyle
	entry.CDNURL = strings.TrimSpace(settings.CDNURL)
	entry.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
	if settings.SecretAccessKey != "" {
		entry.SecretAccessKey = settings.SecretAccessKey
	}
	entry.UpdatedAt = stamp

	return s.persistSoraS3ProfilesStore(ctx, store)
}
// ListSoraS3Profiles loads the stored Sora S3 profiles and returns them in the
// service-level list representation.
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
	store, loadErr := s.loadSoraS3ProfilesStore(ctx)
	if loadErr != nil {
		return nil, loadErr
	}
	list := convertSoraS3ProfilesStore(store)
	return list, nil
}
// CreateSoraS3Profile adds a new Sora S3 profile to the store.
//
// profile.ProfileID and profile.Name are required (after trimming). The call
// fails with ErrSoraS3ProfileExists when the ID is already taken. The new
// profile becomes active when setActive is true or when the store has no
// active profile yet. The created profile is returned as read back from the
// updated store.
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
	if profile == nil {
		return nil, fmt.Errorf("profile cannot be nil")
	}

	id := strings.TrimSpace(profile.ProfileID)
	if id == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}
	displayName := strings.TrimSpace(profile.Name)
	if displayName == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}
	if hasSoraS3ProfileID(store.Items, id) {
		return nil, ErrSoraS3ProfileExists
	}

	entry := soraS3ProfileStoreItem{
		ProfileID:                id,
		Name:                     displayName,
		Enabled:                  profile.Enabled,
		Endpoint:                 strings.TrimSpace(profile.Endpoint),
		Region:                   strings.TrimSpace(profile.Region),
		Bucket:                   strings.TrimSpace(profile.Bucket),
		AccessKeyID:              strings.TrimSpace(profile.AccessKeyID),
		SecretAccessKey:          profile.SecretAccessKey,
		Prefix:                   strings.TrimSpace(profile.Prefix),
		ForcePathStyle:           profile.ForcePathStyle,
		CDNURL:                   strings.TrimSpace(profile.CDNURL),
		DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
		UpdatedAt:                time.Now().UTC().Format(time.RFC3339),
	}
	store.Items = append(store.Items, entry)

	if store.ActiveProfileID == "" || setActive {
		store.ActiveProfileID = id
	}

	if persistErr := s.persistSoraS3ProfilesStore(ctx, store); persistErr != nil {
		return nil, persistErr
	}

	if created := findSoraS3ProfileByID(convertSoraS3ProfilesStore(store).Items, id); created != nil {
		return created, nil
	}
	return nil, ErrSoraS3ProfileNotFound
}
// UpdateSoraS3Profile overwrites the editable fields of the profile identified
// by profileID with the values from profile and persists the store.
//
// The profile ID itself is never changed. An empty profile.SecretAccessKey
// keeps the previously stored secret. ErrSoraS3ProfileNotFound is returned
// when no profile with that ID exists.
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
	if profile == nil {
		return nil, fmt.Errorf("profile cannot be nil")
	}

	id := strings.TrimSpace(profileID)
	if id == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}

	pos := findSoraS3ProfileIndex(store.Items, id)
	if pos < 0 {
		return nil, ErrSoraS3ProfileNotFound
	}

	// The name is validated only after the lookup, so an unknown ID is
	// reported before a missing name.
	displayName := strings.TrimSpace(profile.Name)
	if displayName == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
	}

	entry := &store.Items[pos]
	entry.Name = displayName
	entry.Enabled = profile.Enabled
	entry.Endpoint = strings.TrimSpace(profile.Endpoint)
	entry.Region = strings.TrimSpace(profile.Region)
	entry.Bucket = strings.TrimSpace(profile.Bucket)
	entry.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
	entry.Prefix = strings.TrimSpace(profile.Prefix)
	entry.ForcePathStyle = profile.ForcePathStyle
	entry.CDNURL = strings.TrimSpace(profile.CDNURL)
	entry.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
	if profile.SecretAccessKey != "" {
		entry.SecretAccessKey = profile.SecretAccessKey
	}
	entry.UpdatedAt = time.Now().UTC().Format(time.RFC3339)

	if persistErr := s.persistSoraS3ProfilesStore(ctx, store); persistErr != nil {
		return nil, persistErr
	}

	if updated := findSoraS3ProfileByID(convertSoraS3ProfilesStore(store).Items, id); updated != nil {
		return updated, nil
	}
	return nil, ErrSoraS3ProfileNotFound
}
// DeleteSoraS3Profile removes the profile identified by profileID from the
// store. When the removed profile was the active one, the first remaining
// profile (if any) becomes active; otherwise the active ID is cleared.
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
	id := strings.TrimSpace(profileID)
	if id == "" {
		return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return err
	}

	pos := findSoraS3ProfileIndex(store.Items, id)
	if pos < 0 {
		return ErrSoraS3ProfileNotFound
	}

	wasActive := store.ActiveProfileID == id
	store.Items = append(store.Items[:pos], store.Items[pos+1:]...)

	if wasActive {
		successor := ""
		if len(store.Items) > 0 {
			successor = store.Items[0].ProfileID
		}
		store.ActiveProfileID = successor
	}

	return s.persistSoraS3ProfilesStore(ctx, store)
}
// SetActiveSoraS3Profile marks the profile identified by profileID as the
// active Sora S3 profile, refreshes its UpdatedAt timestamp, persists the
// store, and returns the now-active profile.
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
	id := strings.TrimSpace(profileID)
	if id == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}

	pos := findSoraS3ProfileIndex(store.Items, id)
	if pos < 0 {
		return nil, ErrSoraS3ProfileNotFound
	}

	store.ActiveProfileID = id
	store.Items[pos].UpdatedAt = time.Now().UTC().Format(time.RFC3339)

	if persistErr := s.persistSoraS3ProfilesStore(ctx, store); persistErr != nil {
		return nil, persistErr
	}

	list := convertSoraS3ProfilesStore(store)
	if current := pickActiveSoraS3Profile(list.Items, list.ActiveProfileID); current != nil {
		return current, nil
	}
	return nil, ErrSoraS3ProfileNotFound
}
// loadSoraS3ProfilesStore reads the multi-profile store from the settings
// repository.
//
// Resolution order:
//   - The profiles key exists and is blank: an empty store is returned.
//   - The profiles key exists and decodes: the normalized store is returned.
//   - The profiles key exists but does not decode: the legacy single-config
//     keys are used instead; if those cannot be read either, the decode error
//     is returned.
//   - The profiles key is missing (ErrSettingNotFound): the legacy keys are
//     used; a failure to read them is returned as-is.
//   - Any other repository error is wrapped and returned.
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
	if err != nil {
		if !errors.Is(err, ErrSettingNotFound) {
			return nil, fmt.Errorf("get sora s3 profiles: %w", err)
		}
		// Profiles were never written: fall back to the legacy keys.
		legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
		if legacyErr != nil {
			return nil, legacyErr
		}
		return legacySoraS3SettingsToStore(legacy), nil
	}

	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return &soraS3ProfilesStore{}, nil
	}

	var store soraS3ProfilesStore
	if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
		// Corrupt profiles document: try to recover from the legacy keys and
		// only surface the decode error when that is impossible.
		legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
		if legacyErr != nil {
			return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
		}
		return legacySoraS3SettingsToStore(legacy), nil
	}

	normalized := normalizeSoraS3ProfilesStore(store)
	return &normalized, nil
}

// legacySoraS3SettingsToStore wraps legacy single-config settings in a store
// holding one active "default" profile, or returns an empty store when the
// legacy settings carry no configuration.
func legacySoraS3SettingsToStore(legacy *SoraS3Settings) *soraS3ProfilesStore {
	if isEmptyLegacySoraS3Settings(legacy) {
		return &soraS3ProfilesStore{}
	}
	now := time.Now().UTC().Format(time.RFC3339)
	return &soraS3ProfilesStore{
		ActiveProfileID: "default",
		Items: []soraS3ProfileStoreItem{
			{
				ProfileID:                "default",
				Name:                     "Default",
				Enabled:                  legacy.Enabled,
				Endpoint:                 strings.TrimSpace(legacy.Endpoint),
				Region:                   strings.TrimSpace(legacy.Region),
				Bucket:                   strings.TrimSpace(legacy.Bucket),
				AccessKeyID:              strings.TrimSpace(legacy.AccessKeyID),
				SecretAccessKey:          legacy.SecretAccessKey,
				Prefix:                   strings.TrimSpace(legacy.Prefix),
				ForcePathStyle:           legacy.ForcePathStyle,
				CDNURL:                   strings.TrimSpace(legacy.CDNURL),
				DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
				UpdatedAt:                now,
			},
		},
	}
}
// persistSoraS3ProfilesStore normalizes the store, writes it as JSON under
// SettingKeySoraS3Profiles, and mirrors the active profile into the legacy
// single-configuration keys in the same SetMultiple call. When there is no
// active profile the legacy keys are reset to their empty values. After a
// successful write the onUpdate and onS3Update callbacks are invoked if set.
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
	if store == nil {
		return fmt.Errorf("sora s3 profiles store cannot be nil")
	}

	normalized := normalizeSoraS3ProfilesStore(*store)
	payload, err := json.Marshal(normalized)
	if err != nil {
		return fmt.Errorf("marshal sora s3 profiles: %w", err)
	}

	// A zero-value item produces exactly the "no active profile" values
	// ("false", "", "0") for the legacy keys below.
	var mirrored soraS3ProfileStoreItem
	if current := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID); current != nil {
		mirrored = *current
	}

	updates := make(map[string]string)
	updates[SettingKeySoraS3Profiles] = string(payload)
	updates[SettingKeySoraS3Enabled] = strconv.FormatBool(mirrored.Enabled)
	updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(mirrored.Endpoint)
	updates[SettingKeySoraS3Region] = strings.TrimSpace(mirrored.Region)
	updates[SettingKeySoraS3Bucket] = strings.TrimSpace(mirrored.Bucket)
	updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(mirrored.AccessKeyID)
	updates[SettingKeySoraS3Prefix] = strings.TrimSpace(mirrored.Prefix)
	updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(mirrored.ForcePathStyle)
	updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(mirrored.CDNURL)
	updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(mirrored.DefaultStorageQuotaBytes, 0), 10)
	updates[SettingKeySoraS3SecretAccessKey] = mirrored.SecretAccessKey

	if writeErr := s.settingRepo.SetMultiple(ctx, updates); writeErr != nil {
		return writeErr
	}

	if notify := s.onUpdate; notify != nil {
		notify()
	}
	if notifyS3 := s.onS3Update; notifyS3 != nil {
		notifyS3()
	}
	return nil
}
// getLegacySoraS3Settings reads the legacy single-configuration Sora S3 keys
// from the settings repository and assembles them into a SoraS3Settings.
// Boolean keys are true only when stored as the literal "true"; a quota value
// that does not parse as a base-10 int64 is left at zero.
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
	values, err := s.settingRepo.GetMultiple(ctx, []string{
		SettingKeySoraS3Enabled,
		SettingKeySoraS3Endpoint,
		SettingKeySoraS3Region,
		SettingKeySoraS3Bucket,
		SettingKeySoraS3AccessKeyID,
		SettingKeySoraS3SecretAccessKey,
		SettingKeySoraS3Prefix,
		SettingKeySoraS3ForcePathStyle,
		SettingKeySoraS3CDNURL,
		SettingKeySoraDefaultStorageQuotaBytes,
	})
	if err != nil {
		return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
	}

	legacy := &SoraS3Settings{}
	legacy.Enabled = values[SettingKeySoraS3Enabled] == "true"
	legacy.Endpoint = values[SettingKeySoraS3Endpoint]
	legacy.Region = values[SettingKeySoraS3Region]
	legacy.Bucket = values[SettingKeySoraS3Bucket]
	legacy.AccessKeyID = values[SettingKeySoraS3AccessKeyID]
	legacy.SecretAccessKey = values[SettingKeySoraS3SecretAccessKey]
	legacy.SecretAccessKeyConfigured = values[SettingKeySoraS3SecretAccessKey] != ""
	legacy.Prefix = values[SettingKeySoraS3Prefix]
	legacy.ForcePathStyle = values[SettingKeySoraS3ForcePathStyle] == "true"
	legacy.CDNURL = values[SettingKeySoraS3CDNURL]

	// Best effort: an absent or malformed quota keeps the zero default.
	quota, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64)
	if parseErr == nil {
		legacy.DefaultStorageQuotaBytes = quota
	}
	return legacy, nil
}
// normalizeSoraS3ProfilesStore returns a cleaned copy of store:
//   - textual fields are trimmed (the secret access key is left untouched);
//   - an item with a blank ID gets "profile-<1-based position>";
//   - later items whose ID was already seen are dropped;
//   - a blank name falls back to the profile ID;
//   - negative quotas are raised to zero;
//   - a blank UpdatedAt is filled with the current UTC time (RFC 3339);
//   - the active ID is cleared when no items remain, and otherwise points at
//     the first item when it does not match any remaining profile.
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
	result := soraS3ProfilesStore{
		ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
		Items:           make([]soraS3ProfileStoreItem, 0, len(store.Items)),
	}
	usedIDs := make(map[string]struct{}, len(store.Items))
	fallbackStamp := time.Now().UTC().Format(time.RFC3339)

	for pos, entry := range store.Items {
		entry.ProfileID = strings.TrimSpace(entry.ProfileID)
		if entry.ProfileID == "" {
			entry.ProfileID = fmt.Sprintf("profile-%d", pos+1)
		}
		if _, duplicate := usedIDs[entry.ProfileID]; duplicate {
			continue
		}
		usedIDs[entry.ProfileID] = struct{}{}

		entry.Name = strings.TrimSpace(entry.Name)
		if entry.Name == "" {
			entry.Name = entry.ProfileID
		}
		entry.Endpoint = strings.TrimSpace(entry.Endpoint)
		entry.Region = strings.TrimSpace(entry.Region)
		entry.Bucket = strings.TrimSpace(entry.Bucket)
		entry.AccessKeyID = strings.TrimSpace(entry.AccessKeyID)
		entry.Prefix = strings.TrimSpace(entry.Prefix)
		entry.CDNURL = strings.TrimSpace(entry.CDNURL)
		entry.DefaultStorageQuotaBytes = maxInt64(entry.DefaultStorageQuotaBytes, 0)
		entry.UpdatedAt = strings.TrimSpace(entry.UpdatedAt)
		if entry.UpdatedAt == "" {
			entry.UpdatedAt = fallbackStamp
		}

		result.Items = append(result.Items, entry)
	}

	switch {
	case len(result.Items) == 0:
		result.ActiveProfileID = ""
	case findSoraS3ProfileIndex(result.Items, result.ActiveProfileID) < 0:
		result.ActiveProfileID = result.Items[0].ProfileID
	}
	return result
}
// convertSoraS3ProfilesStore maps the persisted store onto the service-level
// SoraS3ProfileList. IsActive is set for the item whose ID equals the store's
// active ID, and SecretAccessKeyConfigured reports whether a secret is stored.
// A nil store yields an empty list.
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
	if store == nil {
		return &SoraS3ProfileList{}
	}

	profiles := make([]SoraS3Profile, len(store.Items))
	for i := range store.Items {
		src := &store.Items[i]
		profiles[i] = SoraS3Profile{
			ProfileID:                 src.ProfileID,
			Name:                      src.Name,
			IsActive:                  src.ProfileID == store.ActiveProfileID,
			Enabled:                   src.Enabled,
			Endpoint:                  src.Endpoint,
			Region:                    src.Region,
			Bucket:                    src.Bucket,
			AccessKeyID:               src.AccessKeyID,
			SecretAccessKey:           src.SecretAccessKey,
			SecretAccessKeyConfigured: src.SecretAccessKey != "",
			Prefix:                    src.Prefix,
			ForcePathStyle:            src.ForcePathStyle,
			CDNURL:                    src.CDNURL,
			DefaultStorageQuotaBytes:  src.DefaultStorageQuotaBytes,
			UpdatedAt:                 src.UpdatedAt,
		}
	}

	list := &SoraS3ProfileList{}
	list.ActiveProfileID = store.ActiveProfileID
	list.Items = profiles
	return list
}
// pickActiveSoraS3Profile returns a pointer to the profile whose ID equals
// activeProfileID. When no profile matches, the first profile is returned;
// when items is empty the result is nil. The pointer aliases the slice element.
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
	if len(items) == 0 {
		return nil
	}
	for i := 0; i < len(items); i++ {
		if items[i].ProfileID == activeProfileID {
			return &items[i]
		}
	}
	// No match: fall back to the first profile.
	return &items[0]
}
// findSoraS3ProfileByID returns a pointer to the first profile in items whose
// ID equals profileID, or nil when there is none. The pointer aliases the
// slice element.
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
	for i := 0; i < len(items); i++ {
		if candidate := &items[i]; candidate.ProfileID == profileID {
			return candidate
		}
	}
	return nil
}
// pickActiveSoraS3ProfileFromStore returns a pointer to the stored item whose
// ID equals activeProfileID. When no item matches, the first item is returned;
// when items is empty the result is nil. The pointer aliases the slice element.
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
	if len(items) == 0 {
		return nil
	}
	for i := 0; i < len(items); i++ {
		if items[i].ProfileID == activeProfileID {
			return &items[i]
		}
	}
	// No match: fall back to the first item.
	return &items[0]
}
// findSoraS3ProfileIndex returns the position of the first stored item whose
// ID equals profileID, or -1 when no item matches.
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
	for i := 0; i < len(items); i++ {
		if items[i].ProfileID == profileID {
			return i
		}
	}
	return -1
}
// hasSoraS3ProfileID reports whether items contains a profile with the given ID.
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
	pos := findSoraS3ProfileIndex(items, profileID)
	return pos >= 0
}
// isEmptyLegacySoraS3Settings reports whether the legacy settings carry no
// configuration at all: nil, or disabled with every textual field blank
// (whitespace-only counts as blank, except for the secret, which is compared
// verbatim) and a zero default quota. ForcePathStyle is not considered.
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
	if settings == nil {
		return true
	}
	if settings.Enabled || settings.SecretAccessKey != "" {
		return false
	}

	trimmable := []string{
		settings.Endpoint,
		settings.Region,
		settings.Bucket,
		settings.AccessKeyID,
		settings.Prefix,
		settings.CDNURL,
	}
	for _, text := range trimmable {
		if strings.TrimSpace(text) != "" {
			return false
		}
	}

	return settings.DefaultStorageQuotaBytes == 0
}
// maxInt64 clamps value to a lower bound: it returns floor when value is
// smaller than floor, and value otherwise.
//
// The second parameter is deliberately not called "min" so that it does not
// shadow Go's predeclared min function inside the body.
func maxInt64(value int64, floor int64) int64 {
	if value < floor {
		return floor
	}
	return value
}
......@@ -39,6 +39,7 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
DefaultConcurrency int
DefaultBalance float64
......@@ -81,11 +82,52 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
LinuxDoOAuthEnabled bool
Version string
}
// SoraS3Settings is the Sora S3 storage configuration in its legacy
// single-configuration shape.
//
// NOTE(review): SecretAccessKey is described as internal-only, yet its JSON tag
// still serializes it (SoraS3Profile uses `json:"-"` for the same field).
// Confirm that this struct is never marshaled directly into a response.
type SoraS3Settings struct {
	Enabled bool `json:"enabled"`
	Endpoint string `json:"endpoint"`
	Region string `json:"region"`
	Bucket string `json:"bucket"`
	AccessKeyID string `json:"access_key_id"`
	SecretAccessKey string `json:"secret_access_key"` // internal use only; not returned to the frontend directly
	SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // for frontend display
	Prefix string `json:"prefix"`
	ForcePathStyle bool `json:"force_path_style"`
	CDNURL string `json:"cdn_url"`
	DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// SoraS3Profile is one entry of the Sora S3 multi-profile configuration
// (service-internal model).
type SoraS3Profile struct {
	// ProfileID uniquely identifies the profile.
	ProfileID string `json:"profile_id"`
	// Name is the human-readable label of the profile.
	Name string `json:"name"`
	// IsActive is true for the profile whose ID matches the list's active ID.
	IsActive bool `json:"is_active"`
	Enabled bool `json:"enabled"`
	Endpoint string `json:"endpoint"`
	Region string `json:"region"`
	Bucket string `json:"bucket"`
	AccessKeyID string `json:"access_key_id"`
	SecretAccessKey string `json:"-"` // internal use only; excluded from JSON output
	SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // for frontend display
	Prefix string `json:"prefix"`
	ForcePathStyle bool `json:"force_path_style"`
	CDNURL string `json:"cdn_url"`
	DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
	// UpdatedAt is the last-modified timestamp as a string.
	UpdatedAt string `json:"updated_at"`
}
// SoraS3ProfileList is the list of Sora S3 profiles together with the ID of
// the profile that is currently active.
type SoraS3ProfileList struct {
	// ActiveProfileID is the ProfileID of the active profile; empty when there is none.
	ActiveProfileID string `json:"active_profile_id"`
	// Items holds every configured profile.
	Items []SoraS3Profile `json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment