Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bb664d9b
Commit
bb664d9b
authored
Feb 28, 2026
by
yangjianbo
Browse files
feat(sync): full code sync from release
parent
bfc7b339
Changes
245
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
245 of 245+
files are displayed.
Plain diff
Email patch
backend/internal/service/openai_ws_forwarder_retry_payload_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey
(
t
*
testing
.
T
)
{
payload
:=
map
[
string
]
any
{
"model"
:
"gpt-5.3-codex"
,
"prompt_cache_key"
:
"pcache_123"
,
"include"
:
[]
any
{
"reasoning.encrypted_content"
},
"text"
:
map
[
string
]
any
{
"verbosity"
:
"low"
,
},
"tools"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"function"
}},
}
strategy
,
removed
:=
applyOpenAIWSRetryPayloadStrategy
(
payload
,
3
)
require
.
Equal
(
t
,
"trim_optional_fields"
,
strategy
)
require
.
Contains
(
t
,
removed
,
"include"
)
require
.
NotContains
(
t
,
removed
,
"prompt_cache_key"
)
require
.
Equal
(
t
,
"pcache_123"
,
payload
[
"prompt_cache_key"
])
require
.
NotContains
(
t
,
payload
,
"include"
)
require
.
Contains
(
t
,
payload
,
"text"
)
}
func
TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields
(
t
*
testing
.
T
)
{
payload
:=
map
[
string
]
any
{
"prompt_cache_key"
:
"pcache_456"
,
"instructions"
:
"long instructions"
,
"tools"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"function"
}},
"parallel_tool_calls"
:
true
,
"tool_choice"
:
"auto"
,
"include"
:
[]
any
{
"reasoning.encrypted_content"
},
"text"
:
map
[
string
]
any
{
"verbosity"
:
"high"
},
}
strategy
,
removed
:=
applyOpenAIWSRetryPayloadStrategy
(
payload
,
6
)
require
.
Equal
(
t
,
"trim_optional_fields"
,
strategy
)
require
.
Contains
(
t
,
removed
,
"include"
)
require
.
NotContains
(
t
,
removed
,
"prompt_cache_key"
)
require
.
Equal
(
t
,
"pcache_456"
,
payload
[
"prompt_cache_key"
])
require
.
Contains
(
t
,
payload
,
"instructions"
)
require
.
Contains
(
t
,
payload
,
"tools"
)
require
.
Contains
(
t
,
payload
,
"tool_choice"
)
require
.
Contains
(
t
,
payload
,
"parallel_tool_calls"
)
require
.
Contains
(
t
,
payload
,
"text"
)
}
backend/internal/service/openai_ws_forwarder_success_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
type
receivedPayload
struct
{
Type
string
PreviousResponseID
string
StreamExists
bool
Stream
bool
}
receivedCh
:=
make
(
chan
receivedPayload
,
1
)
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
request
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
request
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
requestJSON
:=
requestToJSONString
(
request
)
receivedCh
<-
receivedPayload
{
Type
:
strings
.
TrimSpace
(
gjson
.
Get
(
requestJSON
,
"type"
)
.
String
()),
PreviousResponseID
:
strings
.
TrimSpace
(
gjson
.
Get
(
requestJSON
,
"previous_response_id"
)
.
String
()),
StreamExists
:
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Exists
(),
Stream
:
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Bool
(),
}
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.created"
,
"response"
:
map
[
string
]
any
{
"id"
:
"resp_new_1"
,
"model"
:
"gpt-5.1"
,
},
});
err
!=
nil
{
t
.
Errorf
(
"write response.created failed: %v"
,
err
)
return
}
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
"resp_new_1"
,
"model"
:
"gpt-5.1"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
12
,
"output_tokens"
:
7
,
"input_tokens_details"
:
map
[
string
]
any
{
"cached_tokens"
:
3
,
},
},
},
});
err
!=
nil
{
t
.
Errorf
(
"write response.completed failed: %v"
,
err
)
return
}
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"unit-test-agent/1.0"
)
groupID
:=
int64
(
1001
)
c
.
Set
(
"api_key"
,
&
APIKey
{
GroupID
:
&
groupID
})
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
30
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
10
cfg
.
Gateway
.
OpenAIWS
.
StickyResponseIDTTLSeconds
=
3600
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cache
:=
&
stubGatewayCache
{}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
cache
:
cache
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
9
,
Name
:
"openai-ws"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
2
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
12
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
7
,
result
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
result
.
Usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
"resp_new_1"
,
result
.
RequestID
)
require
.
True
(
t
,
result
.
OpenAIWSMode
)
require
.
False
(
t
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
Exists
(),
"WSv2 成功时不应回落 HTTP 上游"
)
received
:=
<-
receivedCh
require
.
Equal
(
t
,
"response.create"
,
received
.
Type
)
require
.
Equal
(
t
,
"resp_prev_1"
,
received
.
PreviousResponseID
)
require
.
True
(
t
,
received
.
StreamExists
,
"WS 请求应携带 stream 字段"
)
require
.
False
(
t
,
received
.
Stream
,
"应保持客户端 stream=false 的原始语义"
)
store
:=
svc
.
getOpenAIWSStateStore
()
mappedAccountID
,
getErr
:=
store
.
GetResponseAccount
(
context
.
Background
(),
groupID
,
"resp_new_1"
)
require
.
NoError
(
t
,
getErr
)
require
.
Equal
(
t
,
account
.
ID
,
mappedAccountID
)
connID
,
ok
:=
store
.
GetResponseConn
(
"resp_new_1"
)
require
.
True
(
t
,
ok
)
require
.
NotEmpty
(
t
,
connID
)
responseBody
:=
rec
.
Body
.
Bytes
()
require
.
Equal
(
t
,
"resp_new_1"
,
gjson
.
GetBytes
(
responseBody
,
"id"
)
.
String
())
}
func
requestToJSONString
(
payload
map
[
string
]
any
)
string
{
if
len
(
payload
)
==
0
{
return
"{}"
}
b
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
"{}"
}
return
string
(
b
)
}
func
TestLogOpenAIWSBindResponseAccountWarn
(
t
*
testing
.
T
)
{
require
.
NotPanics
(
t
,
func
()
{
logOpenAIWSBindResponseAccountWarn
(
1
,
2
,
"resp_ok"
,
nil
)
})
require
.
NotPanics
(
t
,
func
()
{
logOpenAIWSBindResponseAccountWarn
(
1
,
2
,
"resp_err"
,
errors
.
New
(
"bind failed"
))
})
}
func
TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
groupID
:=
int64
(
3001
)
c
.
Set
(
"api_key"
,
&
APIKey
{
GroupID
:
&
groupID
})
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
5
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
3
captureConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
1301
,
Name
:
"openai-rewrite"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"model_mapping"
:
map
[
string
]
any
{
"custom-original-model"
:
"gpt-5.1"
,
},
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_model_tool_1"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"custom-original-model"
,
gjson
.
GetBytes
(
rec
.
Body
.
Bytes
(),
"model"
)
.
String
(),
"响应模型应回写为原始请求模型"
)
require
.
Equal
(
t
,
"edit"
,
gjson
.
GetBytes
(
rec
.
Body
.
Bytes
(),
"tool_calls.0.function.name"
)
.
String
(),
"工具名称应被修正为 OpenCode 规范"
)
}
func
TestOpenAIWSPayloadString_OnlyAcceptsStringValues
(
t
*
testing
.
T
)
{
payload
:=
map
[
string
]
any
{
"type"
:
nil
,
"model"
:
123
,
"prompt_cache_key"
:
" cache-key "
,
"previous_response_id"
:
[]
byte
(
" resp_1 "
),
}
require
.
Equal
(
t
,
""
,
openAIWSPayloadString
(
payload
,
"type"
))
require
.
Equal
(
t
,
""
,
openAIWSPayloadString
(
payload
,
"model"
))
require
.
Equal
(
t
,
"cache-key"
,
openAIWSPayloadString
(
payload
,
"prompt_cache_key"
))
require
.
Equal
(
t
,
"resp_1"
,
openAIWSPayloadString
(
payload
,
"previous_response_id"
))
}
func
TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
upgradeCount
atomic
.
Int64
var
sequence
atomic
.
Int64
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
upgradeCount
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
for
{
var
request
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
request
);
err
!=
nil
{
return
}
idx
:=
sequence
.
Add
(
1
)
responseID
:=
"resp_reuse_"
+
strconv
.
FormatInt
(
idx
,
10
)
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.created"
,
"response"
:
map
[
string
]
any
{
"id"
:
responseID
,
"model"
:
"gpt-5.1"
,
},
});
err
!=
nil
{
return
}
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
responseID
,
"model"
:
"gpt-5.1"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
2
,
"output_tokens"
:
1
,
},
},
});
err
!=
nil
{
return
}
}
}))
defer
wsServer
.
Close
()
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
30
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
10
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
19
,
Name
:
"openai-ws"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
2
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
for
i
:=
0
;
i
<
2
;
i
++
{
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
groupID
:=
int64
(
2001
)
c
.
Set
(
"api_key"
,
&
APIKey
{
GroupID
:
&
groupID
})
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
strings
.
HasPrefix
(
result
.
RequestID
,
"resp_reuse_"
))
}
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
(),
"多个客户端请求应复用账号连接池而不是 1:1 对等建链"
)
metrics
:=
svc
.
SnapshotOpenAIWSPoolMetrics
()
require
.
GreaterOrEqual
(
t
,
metrics
.
AcquireReuseTotal
,
int64
(
1
))
require
.
GreaterOrEqual
(
t
,
metrics
.
ConnPickTotal
,
int64
(
1
))
}
func
TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
c
.
Request
.
Header
.
Set
(
"session_id"
,
"sess-oauth-1"
)
c
.
Request
.
Header
.
Set
(
"conversation_id"
,
"conv-oauth-1"
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
AllowStoreRecovery
=
false
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
captureConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
29
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token-1"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_oauth_1"
,
result
.
RequestID
)
require
.
NotNil
(
t
,
captureConn
.
lastWrite
)
requestJSON
:=
requestToJSONString
(
captureConn
.
lastWrite
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"store"
)
.
Exists
(),
"OAuth WSv2 应显式写入 store 字段"
)
require
.
False
(
t
,
gjson
.
Get
(
requestJSON
,
"store"
)
.
Bool
(),
"默认策略应将 OAuth store 置为 false"
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Exists
(),
"WSv2 payload 应保留 stream 字段"
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Bool
(),
"OAuth Codex 规范化后应强制 stream=true"
)
require
.
Equal
(
t
,
openAIWSBetaV2Value
,
captureDialer
.
lastHeaders
.
Get
(
"OpenAI-Beta"
))
require
.
Equal
(
t
,
"sess-oauth-1"
,
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
require
.
Equal
(
t
,
"conv-oauth-1"
,
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
}
func
TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
captureConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
31
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token-1"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_prompt_cache_key"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"pcache_123"
,
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
require
.
Empty
(
t
,
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
require
.
NotNil
(
t
,
captureConn
.
lastWrite
)
require
.
True
(
t
,
gjson
.
Get
(
requestToJSONString
(
captureConn
.
lastWrite
),
"stream"
)
.
Exists
())
}
func
TestOpenAIGatewayService_Forward_WSv1_Unsupported
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsockets
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
false
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
39
,
Name
:
"openai-ws-v1"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
"https://api.openai.com/v1/responses"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"ws v1"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"WSv1"
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WSv1 不支持时不应触发 HTTP 上游请求"
)
}
func
TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
connIndex
atomic
.
Int64
headersCh
:=
make
(
chan
http
.
Header
,
4
)
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
idx
:=
connIndex
.
Add
(
1
)
headersCh
<-
cloneHeader
(
r
.
Header
)
respHeader
:=
http
.
Header
{}
if
idx
==
1
{
respHeader
.
Set
(
"x-codex-turn-state"
,
"turn_state_first"
)
}
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
respHeader
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
request
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
request
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
responseID
:=
"resp_turn_"
+
strconv
.
FormatInt
(
idx
,
10
)
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
responseID
,
"model"
:
"gpt-5.1"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
2
,
"output_tokens"
:
1
,
},
},
});
err
!=
nil
{
t
.
Errorf
(
"write response.completed failed: %v"
,
err
)
return
}
}))
defer
wsServer
.
Close
()
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
0
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
49
,
Name
:
"openai-turn-state"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
reqBody
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
rec1
:=
httptest
.
NewRecorder
()
c1
,
_
:=
gin
.
CreateTestContext
(
rec1
)
c1
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c1
.
Request
.
Header
.
Set
(
"session_id"
,
"session_turn_state"
)
c1
.
Request
.
Header
.
Set
(
"x-codex-turn-metadata"
,
"turn_meta_1"
)
result1
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c1
,
account
,
reqBody
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result1
)
sessionHash
:=
svc
.
GenerateSessionHash
(
c1
,
reqBody
)
store
:=
svc
.
getOpenAIWSStateStore
()
turnState
,
ok
:=
store
.
GetSessionTurnState
(
0
,
sessionHash
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"turn_state_first"
,
turnState
)
// 主动淘汰连接,模拟下一次请求发生重连。
connID
,
hasConn
:=
store
.
GetResponseConn
(
result1
.
RequestID
)
require
.
True
(
t
,
hasConn
)
svc
.
getOpenAIWSConnPool
()
.
evictConn
(
account
.
ID
,
connID
)
rec2
:=
httptest
.
NewRecorder
()
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c2
.
Request
.
Header
.
Set
(
"session_id"
,
"session_turn_state"
)
c2
.
Request
.
Header
.
Set
(
"x-codex-turn-metadata"
,
"turn_meta_2"
)
result2
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c2
,
account
,
reqBody
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result2
)
firstHandshakeHeaders
:=
<-
headersCh
secondHandshakeHeaders
:=
<-
headersCh
require
.
Equal
(
t
,
"turn_meta_1"
,
firstHandshakeHeaders
.
Get
(
"X-Codex-Turn-Metadata"
))
require
.
Equal
(
t
,
"turn_meta_2"
,
secondHandshakeHeaders
.
Get
(
"X-Codex-Turn-Metadata"
))
require
.
Equal
(
t
,
"turn_state_first"
,
secondHandshakeHeaders
.
Get
(
"X-Codex-Turn-State"
))
}
func
TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"session_id"
,
"session-prewarm"
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
PrewarmGenerateEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
captureConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`
),
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
59
,
Name
:
"openai-prewarm"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_main_1"
,
result
.
RequestID
)
require
.
Len
(
t
,
captureConn
.
writes
,
2
,
"开启 generate=false 预热后应发送两次 WS 请求"
)
firstWrite
:=
requestToJSONString
(
captureConn
.
writes
[
0
])
secondWrite
:=
requestToJSONString
(
captureConn
.
writes
[
1
])
require
.
True
(
t
,
gjson
.
Get
(
firstWrite
,
"generate"
)
.
Exists
())
require
.
False
(
t
,
gjson
.
Get
(
firstWrite
,
"generate"
)
.
Bool
())
require
.
False
(
t
,
gjson
.
Get
(
secondWrite
,
"generate"
)
.
Exists
())
}
func
TestOpenAIGatewayService_PrewarmReadHonorsParentContext
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
PrewarmGenerateEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
5
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
3
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
601
,
Name
:
"openai-prewarm-timeout"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
}
conn
:=
newOpenAIWSConn
(
"prewarm_ctx_conn"
,
account
.
ID
,
&
openAIWSBlockingConn
{
readDelay
:
200
*
time
.
Millisecond
,
},
nil
)
lease
:=
&
openAIWSConnLease
{
accountID
:
account
.
ID
,
conn
:
conn
,
}
payload
:=
map
[
string
]
any
{
"type"
:
"response.create"
,
"model"
:
"gpt-5.1"
,
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
40
*
time
.
Millisecond
)
defer
cancel
()
start
:=
time
.
Now
()
err
:=
svc
.
performOpenAIWSGeneratePrewarm
(
ctx
,
lease
,
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportResponsesWebsocketV2
},
payload
,
""
,
map
[
string
]
any
{
"model"
:
"gpt-5.1"
},
account
,
nil
,
0
,
)
elapsed
:=
time
.
Since
(
start
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"prewarm_read_event"
)
require
.
Less
(
t
,
elapsed
,
180
*
time
.
Millisecond
,
"预热读取应受父 context 取消控制,不应阻塞到 read_timeout"
)
}
func
TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
captureConn
:=
&
openAIWSCaptureConn
{
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`
),
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
69
,
Name
:
"openai-turn-metadata"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
rec1
:=
httptest
.
NewRecorder
()
c1
,
_
:=
gin
.
CreateTestContext
(
rec1
)
c1
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c1
.
Request
.
Header
.
Set
(
"session_id"
,
"session-metadata-reuse"
)
c1
.
Request
.
Header
.
Set
(
"x-codex-turn-metadata"
,
"turn_meta_payload_1"
)
result1
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c1
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result1
)
require
.
Equal
(
t
,
"resp_meta_1"
,
result1
.
RequestID
)
rec2
:=
httptest
.
NewRecorder
()
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c2
.
Request
.
Header
.
Set
(
"session_id"
,
"session-metadata-reuse"
)
c2
.
Request
.
Header
.
Set
(
"x-codex-turn-metadata"
,
"turn_meta_payload_2"
)
result2
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c2
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result2
)
require
.
Equal
(
t
,
"resp_meta_2"
,
result2
.
RequestID
)
require
.
Equal
(
t
,
1
,
captureDialer
.
DialCount
(),
"同一账号两轮请求应复用同一 WS 连接"
)
require
.
Len
(
t
,
captureConn
.
writes
,
2
)
firstWrite
:=
requestToJSONString
(
captureConn
.
writes
[
0
])
secondWrite
:=
requestToJSONString
(
captureConn
.
writes
[
1
])
require
.
Equal
(
t
,
"turn_meta_payload_1"
,
gjson
.
Get
(
firstWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
require
.
Equal
(
t
,
"turn_meta_payload_2"
,
gjson
.
Get
(
secondWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
}
func
TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
upgradeCount
atomic
.
Int64
var
sequence
atomic
.
Int64
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
upgradeCount
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
for
{
var
request
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
request
);
err
!=
nil
{
return
}
responseID
:=
"resp_store_false_"
+
strconv
.
FormatInt
(
sequence
.
Add
(
1
),
10
)
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
responseID
,
"model"
:
"gpt-5.1"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
1
,
"output_tokens"
:
1
,
},
},
});
err
!=
nil
{
return
}
}
}))
defer
wsServer
.
Close
()
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
StoreDisabledForceNewConn
=
true
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
79
,
Name
:
"openai-store-false"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
2
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`
)
rec1
:=
httptest
.
NewRecorder
()
c1
,
_
:=
gin
.
CreateTestContext
(
rec1
)
c1
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c1
.
Request
.
Header
.
Set
(
"session_id"
,
"session_store_false_a"
)
result1
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c1
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result1
)
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
())
rec2
:=
httptest
.
NewRecorder
()
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c2
.
Request
.
Header
.
Set
(
"session_id"
,
"session_store_false_a"
)
result2
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c2
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result2
)
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
(),
"同一 session(store=false) 应复用同一 WS 连接"
)
rec3
:=
httptest
.
NewRecorder
()
c3
,
_
:=
gin
.
CreateTestContext
(
rec3
)
c3
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c3
.
Request
.
Header
.
Set
(
"session_id"
,
"session_store_false_b"
)
result3
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c3
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result3
)
require
.
Equal
(
t
,
int64
(
2
),
upgradeCount
.
Load
(),
"不同 session(store=false) 应隔离连接,避免续链状态互相覆盖"
)
}
func
TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
upgradeCount
atomic
.
Int64
var
sequence
atomic
.
Int64
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
upgradeCount
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
for
{
var
request
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
request
);
err
!=
nil
{
return
}
responseID
:=
"resp_store_false_reuse_"
+
strconv
.
FormatInt
(
sequence
.
Add
(
1
),
10
)
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
responseID
,
"model"
:
"gpt-5.1"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
1
,
"output_tokens"
:
1
,
},
},
});
err
!=
nil
{
return
}
}
}))
defer
wsServer
.
Close
()
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
StoreDisabledForceNewConn
=
false
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
&
httpUpstreamRecorder
{},
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
80
,
Name
:
"openai-store-false-reuse"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
2
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`
)
rec1
:=
httptest
.
NewRecorder
()
c1
,
_
:=
gin
.
CreateTestContext
(
rec1
)
c1
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c1
.
Request
.
Header
.
Set
(
"session_id"
,
"session_store_false_reuse_a"
)
result1
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c1
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result1
)
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
())
rec2
:=
httptest
.
NewRecorder
()
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c2
.
Request
.
Header
.
Set
(
"session_id"
,
"session_store_false_reuse_b"
)
result2
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c2
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result2
)
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
(),
"关闭强制新连后,不同 session(store=false) 可复用连接"
)
}
func
TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.98.0"
)
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
3
cfg
.
Gateway
.
OpenAIWS
.
ReadTimeoutSeconds
=
1
cfg
.
Gateway
.
OpenAIWS
.
WriteTimeoutSeconds
=
3
captureConn
:=
&
openAIWSCaptureConn
{
readDelays
:
[]
time
.
Duration
{
700
*
time
.
Millisecond
,
700
*
time
.
Millisecond
,
},
events
:
[][]
byte
{
[]
byte
(
`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`
),
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`
),
},
}
captureDialer
:=
&
openAIWSCaptureDialer
{
conn
:
captureConn
}
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
captureDialer
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
cache
:
&
stubGatewayCache
{},
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSPool
:
pool
,
}
account
:=
&
Account
{
ID
:
81
,
Name
:
"openai-read-timeout"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_timeout_ok"
,
result
.
RequestID
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP"
)
}
type
openAIWSCaptureDialer
struct
{
mu
sync
.
Mutex
conn
*
openAIWSCaptureConn
lastHeaders
http
.
Header
handshake
http
.
Header
dialCount
int
}
func
(
d
*
openAIWSCaptureDialer
)
Dial
(
ctx
context
.
Context
,
wsURL
string
,
headers
http
.
Header
,
proxyURL
string
,
)
(
openAIWSClientConn
,
int
,
http
.
Header
,
error
)
{
_
=
ctx
_
=
wsURL
_
=
proxyURL
d
.
mu
.
Lock
()
d
.
lastHeaders
=
cloneHeader
(
headers
)
d
.
dialCount
++
respHeaders
:=
cloneHeader
(
d
.
handshake
)
d
.
mu
.
Unlock
()
return
d
.
conn
,
0
,
respHeaders
,
nil
}
func
(
d
*
openAIWSCaptureDialer
)
DialCount
()
int
{
d
.
mu
.
Lock
()
defer
d
.
mu
.
Unlock
()
return
d
.
dialCount
}
type
openAIWSCaptureConn
struct
{
mu
sync
.
Mutex
readDelays
[]
time
.
Duration
events
[][]
byte
lastWrite
map
[
string
]
any
writes
[]
map
[
string
]
any
closed
bool
}
func
(
c
*
openAIWSCaptureConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
_
=
ctx
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
if
c
.
closed
{
return
errOpenAIWSConnClosed
}
switch
payload
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
c
.
lastWrite
=
cloneMapStringAny
(
payload
)
c
.
writes
=
append
(
c
.
writes
,
cloneMapStringAny
(
payload
))
case
json
.
RawMessage
:
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
payload
,
&
parsed
);
err
==
nil
{
c
.
lastWrite
=
cloneMapStringAny
(
parsed
)
c
.
writes
=
append
(
c
.
writes
,
cloneMapStringAny
(
parsed
))
}
case
[]
byte
:
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
payload
,
&
parsed
);
err
==
nil
{
c
.
lastWrite
=
cloneMapStringAny
(
parsed
)
c
.
writes
=
append
(
c
.
writes
,
cloneMapStringAny
(
parsed
))
}
}
return
nil
}
func
(
c
*
openAIWSCaptureConn
)
ReadMessage
(
ctx
context
.
Context
)
([]
byte
,
error
)
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
c
.
mu
.
Lock
()
if
c
.
closed
{
c
.
mu
.
Unlock
()
return
nil
,
errOpenAIWSConnClosed
}
if
len
(
c
.
events
)
==
0
{
c
.
mu
.
Unlock
()
return
nil
,
io
.
EOF
}
delay
:=
time
.
Duration
(
0
)
if
len
(
c
.
readDelays
)
>
0
{
delay
=
c
.
readDelays
[
0
]
c
.
readDelays
=
c
.
readDelays
[
1
:
]
}
event
:=
c
.
events
[
0
]
c
.
events
=
c
.
events
[
1
:
]
c
.
mu
.
Unlock
()
if
delay
>
0
{
timer
:=
time
.
NewTimer
(
delay
)
defer
timer
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
case
<-
timer
.
C
:
}
}
return
event
,
nil
}
func
(
c
*
openAIWSCaptureConn
)
Ping
(
ctx
context
.
Context
)
error
{
_
=
ctx
return
nil
}
func
(
c
*
openAIWSCaptureConn
)
Close
()
error
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
closed
=
true
return
nil
}
func
cloneMapStringAny
(
src
map
[
string
]
any
)
map
[
string
]
any
{
if
src
==
nil
{
return
nil
}
dst
:=
make
(
map
[
string
]
any
,
len
(
src
))
for
k
,
v
:=
range
src
{
dst
[
k
]
=
v
}
return
dst
}
backend/internal/service/openai_ws_pool.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"golang.org/x/sync/errgroup"
)
const
(
openAIWSConnMaxAge
=
60
*
time
.
Minute
openAIWSConnHealthCheckIdle
=
90
*
time
.
Second
openAIWSConnHealthCheckTO
=
2
*
time
.
Second
openAIWSConnPrewarmExtraDelay
=
2
*
time
.
Second
openAIWSAcquireCleanupInterval
=
3
*
time
.
Second
openAIWSBackgroundPingInterval
=
30
*
time
.
Second
openAIWSBackgroundSweepTicker
=
30
*
time
.
Second
openAIWSPrewarmFailureWindow
=
30
*
time
.
Second
openAIWSPrewarmFailureSuppress
=
2
)
var
(
errOpenAIWSConnClosed
=
errors
.
New
(
"openai ws connection closed"
)
errOpenAIWSConnQueueFull
=
errors
.
New
(
"openai ws connection queue full"
)
errOpenAIWSPreferredConnUnavailable
=
errors
.
New
(
"openai ws preferred connection unavailable"
)
)
type
openAIWSDialError
struct
{
StatusCode
int
ResponseHeaders
http
.
Header
Err
error
}
func
(
e
*
openAIWSDialError
)
Error
()
string
{
if
e
==
nil
{
return
""
}
if
e
.
StatusCode
>
0
{
return
fmt
.
Sprintf
(
"openai ws dial failed: status=%d err=%v"
,
e
.
StatusCode
,
e
.
Err
)
}
return
fmt
.
Sprintf
(
"openai ws dial failed: %v"
,
e
.
Err
)
}
func
(
e
*
openAIWSDialError
)
Unwrap
()
error
{
if
e
==
nil
{
return
nil
}
return
e
.
Err
}
type
openAIWSAcquireRequest
struct
{
Account
*
Account
WSURL
string
Headers
http
.
Header
ProxyURL
string
PreferredConnID
string
// ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。
ForceNewConn
bool
// ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。
ForcePreferredConn
bool
}
type
openAIWSConnLease
struct
{
pool
*
openAIWSConnPool
accountID
int64
conn
*
openAIWSConn
queueWait
time
.
Duration
connPick
time
.
Duration
reused
bool
released
atomic
.
Bool
}
func
(
l
*
openAIWSConnLease
)
activeConn
()
(
*
openAIWSConn
,
error
)
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
nil
,
errOpenAIWSConnClosed
}
if
l
.
released
.
Load
()
{
return
nil
,
errOpenAIWSConnClosed
}
return
l
.
conn
,
nil
}
func
(
l
*
openAIWSConnLease
)
ConnID
()
string
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
""
}
return
l
.
conn
.
id
}
func
(
l
*
openAIWSConnLease
)
QueueWaitDuration
()
time
.
Duration
{
if
l
==
nil
{
return
0
}
return
l
.
queueWait
}
func
(
l
*
openAIWSConnLease
)
ConnPickDuration
()
time
.
Duration
{
if
l
==
nil
{
return
0
}
return
l
.
connPick
}
func
(
l
*
openAIWSConnLease
)
Reused
()
bool
{
if
l
==
nil
{
return
false
}
return
l
.
reused
}
func
(
l
*
openAIWSConnLease
)
HandshakeHeader
(
name
string
)
string
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
""
}
return
l
.
conn
.
handshakeHeader
(
name
)
}
func
(
l
*
openAIWSConnLease
)
IsPrewarmed
()
bool
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
false
}
return
l
.
conn
.
isPrewarmed
()
}
func
(
l
*
openAIWSConnLease
)
MarkPrewarmed
()
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
}
l
.
conn
.
markPrewarmed
()
}
func
(
l
*
openAIWSConnLease
)
WriteJSON
(
value
any
,
timeout
time
.
Duration
)
error
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
err
}
return
conn
.
writeJSONWithTimeout
(
context
.
Background
(),
value
,
timeout
)
}
func
(
l
*
openAIWSConnLease
)
WriteJSONWithContextTimeout
(
ctx
context
.
Context
,
value
any
,
timeout
time
.
Duration
)
error
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
err
}
return
conn
.
writeJSONWithTimeout
(
ctx
,
value
,
timeout
)
}
func
(
l
*
openAIWSConnLease
)
WriteJSONContext
(
ctx
context
.
Context
,
value
any
)
error
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
err
}
return
conn
.
writeJSON
(
value
,
ctx
)
}
func
(
l
*
openAIWSConnLease
)
ReadMessage
(
timeout
time
.
Duration
)
([]
byte
,
error
)
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
nil
,
err
}
return
conn
.
readMessageWithTimeout
(
timeout
)
}
func
(
l
*
openAIWSConnLease
)
ReadMessageContext
(
ctx
context
.
Context
)
([]
byte
,
error
)
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
nil
,
err
}
return
conn
.
readMessage
(
ctx
)
}
func
(
l
*
openAIWSConnLease
)
ReadMessageWithContextTimeout
(
ctx
context
.
Context
,
timeout
time
.
Duration
)
([]
byte
,
error
)
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
nil
,
err
}
return
conn
.
readMessageWithContextTimeout
(
ctx
,
timeout
)
}
func
(
l
*
openAIWSConnLease
)
PingWithTimeout
(
timeout
time
.
Duration
)
error
{
conn
,
err
:=
l
.
activeConn
()
if
err
!=
nil
{
return
err
}
return
conn
.
pingWithTimeout
(
timeout
)
}
func
(
l
*
openAIWSConnLease
)
MarkBroken
()
{
if
l
==
nil
||
l
.
pool
==
nil
||
l
.
conn
==
nil
||
l
.
released
.
Load
()
{
return
}
l
.
pool
.
evictConn
(
l
.
accountID
,
l
.
conn
.
id
)
}
func
(
l
*
openAIWSConnLease
)
Release
()
{
if
l
==
nil
||
l
.
conn
==
nil
{
return
}
if
!
l
.
released
.
CompareAndSwap
(
false
,
true
)
{
return
}
l
.
conn
.
release
()
}
type
openAIWSConn
struct
{
id
string
ws
openAIWSClientConn
handshakeHeaders
http
.
Header
leaseCh
chan
struct
{}
closedCh
chan
struct
{}
closeOnce
sync
.
Once
readMu
sync
.
Mutex
writeMu
sync
.
Mutex
waiters
atomic
.
Int32
createdAtNano
atomic
.
Int64
lastUsedNano
atomic
.
Int64
prewarmed
atomic
.
Bool
}
func
newOpenAIWSConn
(
id
string
,
_
int64
,
ws
openAIWSClientConn
,
handshakeHeaders
http
.
Header
)
*
openAIWSConn
{
now
:=
time
.
Now
()
conn
:=
&
openAIWSConn
{
id
:
id
,
ws
:
ws
,
handshakeHeaders
:
cloneHeader
(
handshakeHeaders
),
leaseCh
:
make
(
chan
struct
{},
1
),
closedCh
:
make
(
chan
struct
{}),
}
conn
.
leaseCh
<-
struct
{}{}
conn
.
createdAtNano
.
Store
(
now
.
UnixNano
())
conn
.
lastUsedNano
.
Store
(
now
.
UnixNano
())
return
conn
}
func
(
c
*
openAIWSConn
)
tryAcquire
()
bool
{
if
c
==
nil
{
return
false
}
select
{
case
<-
c
.
closedCh
:
return
false
default
:
}
select
{
case
<-
c
.
leaseCh
:
select
{
case
<-
c
.
closedCh
:
c
.
release
()
return
false
default
:
}
return
true
default
:
return
false
}
}
func
(
c
*
openAIWSConn
)
acquire
(
ctx
context
.
Context
)
error
{
if
c
==
nil
{
return
errOpenAIWSConnClosed
}
for
{
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
case
<-
c
.
closedCh
:
return
errOpenAIWSConnClosed
case
<-
c
.
leaseCh
:
select
{
case
<-
c
.
closedCh
:
c
.
release
()
return
errOpenAIWSConnClosed
default
:
}
return
nil
}
}
}
func
(
c
*
openAIWSConn
)
release
()
{
if
c
==
nil
{
return
}
select
{
case
c
.
leaseCh
<-
struct
{}{}
:
default
:
}
c
.
touch
()
}
func
(
c
*
openAIWSConn
)
close
()
{
if
c
==
nil
{
return
}
c
.
closeOnce
.
Do
(
func
()
{
close
(
c
.
closedCh
)
if
c
.
ws
!=
nil
{
_
=
c
.
ws
.
Close
()
}
select
{
case
c
.
leaseCh
<-
struct
{}{}
:
default
:
}
})
}
func
(
c
*
openAIWSConn
)
writeJSONWithTimeout
(
parent
context
.
Context
,
value
any
,
timeout
time
.
Duration
)
error
{
if
c
==
nil
{
return
errOpenAIWSConnClosed
}
select
{
case
<-
c
.
closedCh
:
return
errOpenAIWSConnClosed
default
:
}
writeCtx
:=
parent
if
writeCtx
==
nil
{
writeCtx
=
context
.
Background
()
}
if
timeout
<=
0
{
return
c
.
writeJSON
(
value
,
writeCtx
)
}
var
cancel
context
.
CancelFunc
writeCtx
,
cancel
=
context
.
WithTimeout
(
writeCtx
,
timeout
)
defer
cancel
()
return
c
.
writeJSON
(
value
,
writeCtx
)
}
func
(
c
*
openAIWSConn
)
writeJSON
(
value
any
,
writeCtx
context
.
Context
)
error
{
c
.
writeMu
.
Lock
()
defer
c
.
writeMu
.
Unlock
()
if
c
.
ws
==
nil
{
return
errOpenAIWSConnClosed
}
if
writeCtx
==
nil
{
writeCtx
=
context
.
Background
()
}
if
err
:=
c
.
ws
.
WriteJSON
(
writeCtx
,
value
);
err
!=
nil
{
return
err
}
c
.
touch
()
return
nil
}
func
(
c
*
openAIWSConn
)
readMessageWithTimeout
(
timeout
time
.
Duration
)
([]
byte
,
error
)
{
return
c
.
readMessageWithContextTimeout
(
context
.
Background
(),
timeout
)
}
func
(
c
*
openAIWSConn
)
readMessageWithContextTimeout
(
parent
context
.
Context
,
timeout
time
.
Duration
)
([]
byte
,
error
)
{
if
c
==
nil
{
return
nil
,
errOpenAIWSConnClosed
}
select
{
case
<-
c
.
closedCh
:
return
nil
,
errOpenAIWSConnClosed
default
:
}
if
parent
==
nil
{
parent
=
context
.
Background
()
}
if
timeout
<=
0
{
return
c
.
readMessage
(
parent
)
}
readCtx
,
cancel
:=
context
.
WithTimeout
(
parent
,
timeout
)
defer
cancel
()
return
c
.
readMessage
(
readCtx
)
}
func
(
c
*
openAIWSConn
)
readMessage
(
readCtx
context
.
Context
)
([]
byte
,
error
)
{
c
.
readMu
.
Lock
()
defer
c
.
readMu
.
Unlock
()
if
c
.
ws
==
nil
{
return
nil
,
errOpenAIWSConnClosed
}
if
readCtx
==
nil
{
readCtx
=
context
.
Background
()
}
payload
,
err
:=
c
.
ws
.
ReadMessage
(
readCtx
)
if
err
!=
nil
{
return
nil
,
err
}
c
.
touch
()
return
payload
,
nil
}
func
(
c
*
openAIWSConn
)
pingWithTimeout
(
timeout
time
.
Duration
)
error
{
if
c
==
nil
{
return
errOpenAIWSConnClosed
}
select
{
case
<-
c
.
closedCh
:
return
errOpenAIWSConnClosed
default
:
}
c
.
writeMu
.
Lock
()
defer
c
.
writeMu
.
Unlock
()
if
c
.
ws
==
nil
{
return
errOpenAIWSConnClosed
}
if
timeout
<=
0
{
timeout
=
openAIWSConnHealthCheckTO
}
pingCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
timeout
)
defer
cancel
()
if
err
:=
c
.
ws
.
Ping
(
pingCtx
);
err
!=
nil
{
return
err
}
return
nil
}
func
(
c
*
openAIWSConn
)
touch
()
{
if
c
==
nil
{
return
}
c
.
lastUsedNano
.
Store
(
time
.
Now
()
.
UnixNano
())
}
func
(
c
*
openAIWSConn
)
createdAt
()
time
.
Time
{
if
c
==
nil
{
return
time
.
Time
{}
}
nano
:=
c
.
createdAtNano
.
Load
()
if
nano
<=
0
{
return
time
.
Time
{}
}
return
time
.
Unix
(
0
,
nano
)
}
func
(
c
*
openAIWSConn
)
lastUsedAt
()
time
.
Time
{
if
c
==
nil
{
return
time
.
Time
{}
}
nano
:=
c
.
lastUsedNano
.
Load
()
if
nano
<=
0
{
return
time
.
Time
{}
}
return
time
.
Unix
(
0
,
nano
)
}
func
(
c
*
openAIWSConn
)
idleDuration
(
now
time
.
Time
)
time
.
Duration
{
if
c
==
nil
{
return
0
}
last
:=
c
.
lastUsedAt
()
if
last
.
IsZero
()
{
return
0
}
return
now
.
Sub
(
last
)
}
func
(
c
*
openAIWSConn
)
age
(
now
time
.
Time
)
time
.
Duration
{
if
c
==
nil
{
return
0
}
created
:=
c
.
createdAt
()
if
created
.
IsZero
()
{
return
0
}
return
now
.
Sub
(
created
)
}
func
(
c
*
openAIWSConn
)
isLeased
()
bool
{
if
c
==
nil
{
return
false
}
return
len
(
c
.
leaseCh
)
==
0
}
func
(
c
*
openAIWSConn
)
handshakeHeader
(
name
string
)
string
{
if
c
==
nil
||
c
.
handshakeHeaders
==
nil
{
return
""
}
return
strings
.
TrimSpace
(
c
.
handshakeHeaders
.
Get
(
strings
.
TrimSpace
(
name
)))
}
func
(
c
*
openAIWSConn
)
isPrewarmed
()
bool
{
if
c
==
nil
{
return
false
}
return
c
.
prewarmed
.
Load
()
}
func
(
c
*
openAIWSConn
)
markPrewarmed
()
{
if
c
==
nil
{
return
}
c
.
prewarmed
.
Store
(
true
)
}
type
openAIWSAccountPool
struct
{
mu
sync
.
Mutex
conns
map
[
string
]
*
openAIWSConn
pinnedConns
map
[
string
]
int
creating
int
lastCleanupAt
time
.
Time
lastAcquire
*
openAIWSAcquireRequest
prewarmActive
bool
prewarmUntil
time
.
Time
prewarmFails
int
prewarmFailAt
time
.
Time
}
type
OpenAIWSPoolMetricsSnapshot
struct
{
AcquireTotal
int64
AcquireReuseTotal
int64
AcquireCreateTotal
int64
AcquireQueueWaitTotal
int64
AcquireQueueWaitMsTotal
int64
ConnPickTotal
int64
ConnPickMsTotal
int64
ScaleUpTotal
int64
ScaleDownTotal
int64
}
type
openAIWSPoolMetrics
struct
{
acquireTotal
atomic
.
Int64
acquireReuseTotal
atomic
.
Int64
acquireCreateTotal
atomic
.
Int64
acquireQueueWaitTotal
atomic
.
Int64
acquireQueueWaitMs
atomic
.
Int64
connPickTotal
atomic
.
Int64
connPickMs
atomic
.
Int64
scaleUpTotal
atomic
.
Int64
scaleDownTotal
atomic
.
Int64
}
type
openAIWSConnPool
struct
{
cfg
*
config
.
Config
// 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。
clientDialer
openAIWSClientDialer
accounts
sync
.
Map
// key: int64(accountID), value: *openAIWSAccountPool
seq
atomic
.
Uint64
metrics
openAIWSPoolMetrics
workerStopCh
chan
struct
{}
workerWg
sync
.
WaitGroup
closeOnce
sync
.
Once
}
func
newOpenAIWSConnPool
(
cfg
*
config
.
Config
)
*
openAIWSConnPool
{
pool
:=
&
openAIWSConnPool
{
cfg
:
cfg
,
clientDialer
:
newDefaultOpenAIWSClientDialer
(),
workerStopCh
:
make
(
chan
struct
{}),
}
pool
.
startBackgroundWorkers
()
return
pool
}
func
(
p
*
openAIWSConnPool
)
SnapshotMetrics
()
OpenAIWSPoolMetricsSnapshot
{
if
p
==
nil
{
return
OpenAIWSPoolMetricsSnapshot
{}
}
return
OpenAIWSPoolMetricsSnapshot
{
AcquireTotal
:
p
.
metrics
.
acquireTotal
.
Load
(),
AcquireReuseTotal
:
p
.
metrics
.
acquireReuseTotal
.
Load
(),
AcquireCreateTotal
:
p
.
metrics
.
acquireCreateTotal
.
Load
(),
AcquireQueueWaitTotal
:
p
.
metrics
.
acquireQueueWaitTotal
.
Load
(),
AcquireQueueWaitMsTotal
:
p
.
metrics
.
acquireQueueWaitMs
.
Load
(),
ConnPickTotal
:
p
.
metrics
.
connPickTotal
.
Load
(),
ConnPickMsTotal
:
p
.
metrics
.
connPickMs
.
Load
(),
ScaleUpTotal
:
p
.
metrics
.
scaleUpTotal
.
Load
(),
ScaleDownTotal
:
p
.
metrics
.
scaleDownTotal
.
Load
(),
}
}
func
(
p
*
openAIWSConnPool
)
SnapshotTransportMetrics
()
OpenAIWSTransportMetricsSnapshot
{
if
p
==
nil
{
return
OpenAIWSTransportMetricsSnapshot
{}
}
if
dialer
,
ok
:=
p
.
clientDialer
.
(
openAIWSTransportMetricsDialer
);
ok
{
return
dialer
.
SnapshotTransportMetrics
()
}
return
OpenAIWSTransportMetricsSnapshot
{}
}
func
(
p
*
openAIWSConnPool
)
setClientDialerForTest
(
dialer
openAIWSClientDialer
)
{
if
p
==
nil
||
dialer
==
nil
{
return
}
p
.
clientDialer
=
dialer
}
// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。
func
(
p
*
openAIWSConnPool
)
Close
()
{
if
p
==
nil
{
return
}
p
.
closeOnce
.
Do
(
func
()
{
if
p
.
workerStopCh
!=
nil
{
close
(
p
.
workerStopCh
)
}
p
.
workerWg
.
Wait
()
// 遍历所有账户池,关闭全部空闲连接。
p
.
accounts
.
Range
(
func
(
key
,
value
any
)
bool
{
ap
,
ok
:=
value
.
(
*
openAIWSAccountPool
)
if
!
ok
||
ap
==
nil
{
return
true
}
ap
.
mu
.
Lock
()
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
!=
nil
&&
!
conn
.
isLeased
()
{
conn
.
close
()
}
}
ap
.
mu
.
Unlock
()
return
true
})
})
}
func
(
p
*
openAIWSConnPool
)
startBackgroundWorkers
()
{
if
p
==
nil
||
p
.
workerStopCh
==
nil
{
return
}
p
.
workerWg
.
Add
(
2
)
go
func
()
{
defer
p
.
workerWg
.
Done
()
p
.
runBackgroundPingWorker
()
}()
go
func
()
{
defer
p
.
workerWg
.
Done
()
p
.
runBackgroundCleanupWorker
()
}()
}
type
openAIWSIdlePingCandidate
struct
{
accountID
int64
conn
*
openAIWSConn
}
func
(
p
*
openAIWSConnPool
)
runBackgroundPingWorker
()
{
if
p
==
nil
{
return
}
ticker
:=
time
.
NewTicker
(
openAIWSBackgroundPingInterval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ticker
.
C
:
p
.
runBackgroundPingSweep
()
case
<-
p
.
workerStopCh
:
return
}
}
}
func
(
p
*
openAIWSConnPool
)
runBackgroundPingSweep
()
{
if
p
==
nil
{
return
}
candidates
:=
p
.
snapshotIdleConnsForPing
()
var
g
errgroup
.
Group
g
.
SetLimit
(
10
)
for
_
,
item
:=
range
candidates
{
item
:=
item
if
item
.
conn
==
nil
||
item
.
conn
.
isLeased
()
||
item
.
conn
.
waiters
.
Load
()
>
0
{
continue
}
g
.
Go
(
func
()
error
{
if
err
:=
item
.
conn
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
p
.
evictConn
(
item
.
accountID
,
item
.
conn
.
id
)
}
return
nil
})
}
_
=
g
.
Wait
()
}
func
(
p
*
openAIWSConnPool
)
snapshotIdleConnsForPing
()
[]
openAIWSIdlePingCandidate
{
if
p
==
nil
{
return
nil
}
candidates
:=
make
([]
openAIWSIdlePingCandidate
,
0
)
p
.
accounts
.
Range
(
func
(
key
,
value
any
)
bool
{
accountID
,
ok
:=
key
.
(
int64
)
if
!
ok
||
accountID
<=
0
{
return
true
}
ap
,
ok
:=
value
.
(
*
openAIWSAccountPool
)
if
!
ok
||
ap
==
nil
{
return
true
}
ap
.
mu
.
Lock
()
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
||
conn
.
isLeased
()
||
conn
.
waiters
.
Load
()
>
0
{
continue
}
candidates
=
append
(
candidates
,
openAIWSIdlePingCandidate
{
accountID
:
accountID
,
conn
:
conn
,
})
}
ap
.
mu
.
Unlock
()
return
true
})
return
candidates
}
func
(
p
*
openAIWSConnPool
)
runBackgroundCleanupWorker
()
{
if
p
==
nil
{
return
}
ticker
:=
time
.
NewTicker
(
openAIWSBackgroundSweepTicker
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ticker
.
C
:
p
.
runBackgroundCleanupSweep
(
time
.
Now
())
case
<-
p
.
workerStopCh
:
return
}
}
}
func
(
p
*
openAIWSConnPool
)
runBackgroundCleanupSweep
(
now
time
.
Time
)
{
if
p
==
nil
{
return
}
type
cleanupResult
struct
{
evicted
[]
*
openAIWSConn
}
results
:=
make
([]
cleanupResult
,
0
)
p
.
accounts
.
Range
(
func
(
_
any
,
value
any
)
bool
{
ap
,
ok
:=
value
.
(
*
openAIWSAccountPool
)
if
!
ok
||
ap
==
nil
{
return
true
}
maxConns
:=
p
.
maxConnsHardCap
()
ap
.
mu
.
Lock
()
if
ap
.
lastAcquire
!=
nil
&&
ap
.
lastAcquire
.
Account
!=
nil
{
maxConns
=
p
.
effectiveMaxConnsByAccount
(
ap
.
lastAcquire
.
Account
)
}
evicted
:=
p
.
cleanupAccountLocked
(
ap
,
now
,
maxConns
)
ap
.
lastCleanupAt
=
now
ap
.
mu
.
Unlock
()
if
len
(
evicted
)
>
0
{
results
=
append
(
results
,
cleanupResult
{
evicted
:
evicted
})
}
return
true
})
for
_
,
result
:=
range
results
{
closeOpenAIWSConns
(
result
.
evicted
)
}
}
func
(
p
*
openAIWSConnPool
)
Acquire
(
ctx
context
.
Context
,
req
openAIWSAcquireRequest
)
(
*
openAIWSConnLease
,
error
)
{
if
p
!=
nil
{
p
.
metrics
.
acquireTotal
.
Add
(
1
)
}
return
p
.
acquire
(
ctx
,
cloneOpenAIWSAcquireRequest
(
req
),
0
)
}
func
(
p
*
openAIWSConnPool
)
acquire
(
ctx
context
.
Context
,
req
openAIWSAcquireRequest
,
retry
int
)
(
*
openAIWSConnLease
,
error
)
{
if
p
==
nil
||
req
.
Account
==
nil
||
req
.
Account
.
ID
<=
0
{
return
nil
,
errors
.
New
(
"invalid ws acquire request"
)
}
if
stringsTrim
(
req
.
WSURL
)
==
""
{
return
nil
,
errors
.
New
(
"ws url is empty"
)
}
accountID
:=
req
.
Account
.
ID
effectiveMaxConns
:=
p
.
effectiveMaxConnsByAccount
(
req
.
Account
)
if
effectiveMaxConns
<=
0
{
return
nil
,
errOpenAIWSConnQueueFull
}
var
evicted
[]
*
openAIWSConn
ap
:=
p
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
lastAcquire
=
cloneOpenAIWSAcquireRequestPtr
(
&
req
)
now
:=
time
.
Now
()
if
ap
.
lastCleanupAt
.
IsZero
()
||
now
.
Sub
(
ap
.
lastCleanupAt
)
>=
openAIWSAcquireCleanupInterval
{
evicted
=
p
.
cleanupAccountLocked
(
ap
,
now
,
effectiveMaxConns
)
ap
.
lastCleanupAt
=
now
}
pickStartedAt
:=
time
.
Now
()
allowReuse
:=
!
req
.
ForceNewConn
preferredConnID
:=
stringsTrim
(
req
.
PreferredConnID
)
forcePreferredConn
:=
allowReuse
&&
req
.
ForcePreferredConn
if
allowReuse
{
if
forcePreferredConn
{
if
preferredConnID
==
""
{
p
.
recordConnPickDuration
(
time
.
Since
(
pickStartedAt
))
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSPreferredConnUnavailable
}
preferredConn
,
ok
:=
ap
.
conns
[
preferredConnID
]
if
!
ok
||
preferredConn
==
nil
{
p
.
recordConnPickDuration
(
time
.
Since
(
pickStartedAt
))
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSPreferredConnUnavailable
}
if
preferredConn
.
tryAcquire
()
{
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
if
p
.
shouldHealthCheckConn
(
preferredConn
)
{
if
err
:=
preferredConn
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
preferredConn
.
close
()
p
.
evictConn
(
accountID
,
preferredConn
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
preferredConn
,
connPick
:
connPick
,
reused
:
true
,
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
if
int
(
preferredConn
.
waiters
.
Load
())
>=
p
.
queueLimitPerConn
()
{
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSConnQueueFull
}
preferredConn
.
waiters
.
Add
(
1
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
defer
preferredConn
.
waiters
.
Add
(
-
1
)
waitStart
:=
time
.
Now
()
p
.
metrics
.
acquireQueueWaitTotal
.
Add
(
1
)
if
err
:=
preferredConn
.
acquire
(
ctx
);
err
!=
nil
{
if
errors
.
Is
(
err
,
errOpenAIWSConnClosed
)
&&
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
if
p
.
shouldHealthCheckConn
(
preferredConn
)
{
if
err
:=
preferredConn
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
preferredConn
.
release
()
preferredConn
.
close
()
p
.
evictConn
(
accountID
,
preferredConn
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
queueWait
:=
time
.
Since
(
waitStart
)
p
.
metrics
.
acquireQueueWaitMs
.
Add
(
queueWait
.
Milliseconds
())
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
preferredConn
,
queueWait
:
queueWait
,
connPick
:
connPick
,
reused
:
true
,
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
if
preferredConnID
!=
""
{
if
conn
,
ok
:=
ap
.
conns
[
preferredConnID
];
ok
&&
conn
.
tryAcquire
()
{
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
if
p
.
shouldHealthCheckConn
(
conn
)
{
if
err
:=
conn
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
conn
.
close
()
p
.
evictConn
(
accountID
,
conn
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
conn
,
connPick
:
connPick
,
reused
:
true
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
}
best
:=
p
.
pickLeastBusyConnLocked
(
ap
,
""
)
if
best
!=
nil
&&
best
.
tryAcquire
()
{
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
if
p
.
shouldHealthCheckConn
(
best
)
{
if
err
:=
best
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
best
.
close
()
p
.
evictConn
(
accountID
,
best
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
best
,
connPick
:
connPick
,
reused
:
true
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
||
conn
==
best
{
continue
}
if
conn
.
tryAcquire
()
{
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
if
p
.
shouldHealthCheckConn
(
conn
)
{
if
err
:=
conn
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
conn
.
close
()
p
.
evictConn
(
accountID
,
conn
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
conn
,
connPick
:
connPick
,
reused
:
true
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
}
}
if
req
.
ForceNewConn
&&
len
(
ap
.
conns
)
+
ap
.
creating
>=
effectiveMaxConns
{
if
idle
:=
p
.
pickOldestIdleConnLocked
(
ap
);
idle
!=
nil
{
delete
(
ap
.
conns
,
idle
.
id
)
evicted
=
append
(
evicted
,
idle
)
p
.
metrics
.
scaleDownTotal
.
Add
(
1
)
}
}
if
len
(
ap
.
conns
)
+
ap
.
creating
<
effectiveMaxConns
{
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
ap
.
creating
++
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
conn
,
dialErr
:=
p
.
dialConn
(
ctx
,
req
)
ap
=
p
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
creating
--
if
dialErr
!=
nil
{
ap
.
prewarmFails
++
ap
.
prewarmFailAt
=
time
.
Now
()
ap
.
mu
.
Unlock
()
return
nil
,
dialErr
}
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
prewarmFails
=
0
ap
.
prewarmFailAt
=
time
.
Time
{}
ap
.
mu
.
Unlock
()
p
.
metrics
.
acquireCreateTotal
.
Add
(
1
)
if
!
conn
.
tryAcquire
()
{
if
err
:=
conn
.
acquire
(
ctx
);
err
!=
nil
{
conn
.
close
()
p
.
evictConn
(
accountID
,
conn
.
id
)
return
nil
,
err
}
}
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
conn
,
connPick
:
connPick
}
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
if
req
.
ForceNewConn
{
p
.
recordConnPickDuration
(
time
.
Since
(
pickStartedAt
))
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSConnQueueFull
}
target
:=
p
.
pickLeastBusyConnLocked
(
ap
,
req
.
PreferredConnID
)
connPick
:=
time
.
Since
(
pickStartedAt
)
p
.
recordConnPickDuration
(
connPick
)
if
target
==
nil
{
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSConnClosed
}
if
int
(
target
.
waiters
.
Load
())
>=
p
.
queueLimitPerConn
()
{
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
return
nil
,
errOpenAIWSConnQueueFull
}
target
.
waiters
.
Add
(
1
)
ap
.
mu
.
Unlock
()
closeOpenAIWSConns
(
evicted
)
defer
target
.
waiters
.
Add
(
-
1
)
waitStart
:=
time
.
Now
()
p
.
metrics
.
acquireQueueWaitTotal
.
Add
(
1
)
if
err
:=
target
.
acquire
(
ctx
);
err
!=
nil
{
if
errors
.
Is
(
err
,
errOpenAIWSConnClosed
)
&&
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
if
p
.
shouldHealthCheckConn
(
target
)
{
if
err
:=
target
.
pingWithTimeout
(
openAIWSConnHealthCheckTO
);
err
!=
nil
{
target
.
release
()
target
.
close
()
p
.
evictConn
(
accountID
,
target
.
id
)
if
retry
<
1
{
return
p
.
acquire
(
ctx
,
req
,
retry
+
1
)
}
return
nil
,
err
}
}
queueWait
:=
time
.
Since
(
waitStart
)
p
.
metrics
.
acquireQueueWaitMs
.
Add
(
queueWait
.
Milliseconds
())
lease
:=
&
openAIWSConnLease
{
pool
:
p
,
accountID
:
accountID
,
conn
:
target
,
queueWait
:
queueWait
,
connPick
:
connPick
,
reused
:
true
}
p
.
metrics
.
acquireReuseTotal
.
Add
(
1
)
p
.
ensureTargetIdleAsync
(
accountID
)
return
lease
,
nil
}
func
(
p
*
openAIWSConnPool
)
recordConnPickDuration
(
duration
time
.
Duration
)
{
if
p
==
nil
{
return
}
if
duration
<
0
{
duration
=
0
}
p
.
metrics
.
connPickTotal
.
Add
(
1
)
p
.
metrics
.
connPickMs
.
Add
(
duration
.
Milliseconds
())
}
func
(
p
*
openAIWSConnPool
)
pickOldestIdleConnLocked
(
ap
*
openAIWSAccountPool
)
*
openAIWSConn
{
if
ap
==
nil
||
len
(
ap
.
conns
)
==
0
{
return
nil
}
var
oldest
*
openAIWSConn
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
||
conn
.
isLeased
()
||
conn
.
waiters
.
Load
()
>
0
||
p
.
isConnPinnedLocked
(
ap
,
conn
.
id
)
{
continue
}
if
oldest
==
nil
||
conn
.
lastUsedAt
()
.
Before
(
oldest
.
lastUsedAt
())
{
oldest
=
conn
}
}
return
oldest
}
func
(
p
*
openAIWSConnPool
)
getOrCreateAccountPool
(
accountID
int64
)
*
openAIWSAccountPool
{
if
p
==
nil
||
accountID
<=
0
{
return
nil
}
if
existing
,
ok
:=
p
.
accounts
.
Load
(
accountID
);
ok
{
if
ap
,
typed
:=
existing
.
(
*
openAIWSAccountPool
);
typed
&&
ap
!=
nil
{
return
ap
}
}
ap
:=
&
openAIWSAccountPool
{
conns
:
make
(
map
[
string
]
*
openAIWSConn
),
pinnedConns
:
make
(
map
[
string
]
int
),
}
actual
,
_
:=
p
.
accounts
.
LoadOrStore
(
accountID
,
ap
)
if
typed
,
ok
:=
actual
.
(
*
openAIWSAccountPool
);
ok
&&
typed
!=
nil
{
return
typed
}
return
ap
}
// ensureAccountPoolLocked 兼容旧调用。
func
(
p
*
openAIWSConnPool
)
ensureAccountPoolLocked
(
accountID
int64
)
*
openAIWSAccountPool
{
return
p
.
getOrCreateAccountPool
(
accountID
)
}
func
(
p
*
openAIWSConnPool
)
getAccountPool
(
accountID
int64
)
(
*
openAIWSAccountPool
,
bool
)
{
if
p
==
nil
||
accountID
<=
0
{
return
nil
,
false
}
value
,
ok
:=
p
.
accounts
.
Load
(
accountID
)
if
!
ok
||
value
==
nil
{
return
nil
,
false
}
ap
,
typed
:=
value
.
(
*
openAIWSAccountPool
)
return
ap
,
typed
&&
ap
!=
nil
}
func
(
p
*
openAIWSConnPool
)
isConnPinnedLocked
(
ap
*
openAIWSAccountPool
,
connID
string
)
bool
{
if
ap
==
nil
||
connID
==
""
||
len
(
ap
.
pinnedConns
)
==
0
{
return
false
}
return
ap
.
pinnedConns
[
connID
]
>
0
}
func
(
p
*
openAIWSConnPool
)
cleanupAccountLocked
(
ap
*
openAIWSAccountPool
,
now
time
.
Time
,
maxConns
int
)
[]
*
openAIWSConn
{
if
ap
==
nil
{
return
nil
}
maxAge
:=
p
.
maxConnAge
()
evicted
:=
make
([]
*
openAIWSConn
,
0
)
for
id
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
{
delete
(
ap
.
conns
,
id
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
id
)
}
continue
}
select
{
case
<-
conn
.
closedCh
:
delete
(
ap
.
conns
,
id
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
id
)
}
evicted
=
append
(
evicted
,
conn
)
continue
default
:
}
if
p
.
isConnPinnedLocked
(
ap
,
id
)
{
continue
}
if
maxAge
>
0
&&
!
conn
.
isLeased
()
&&
conn
.
age
(
now
)
>
maxAge
{
delete
(
ap
.
conns
,
id
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
id
)
}
evicted
=
append
(
evicted
,
conn
)
}
}
if
maxConns
<=
0
{
maxConns
=
p
.
maxConnsHardCap
()
}
maxIdle
:=
p
.
maxIdlePerAccount
()
if
maxIdle
<
0
||
maxIdle
>
maxConns
{
maxIdle
=
maxConns
}
if
maxIdle
>=
0
&&
len
(
ap
.
conns
)
>
maxIdle
{
idleConns
:=
make
([]
*
openAIWSConn
,
0
,
len
(
ap
.
conns
))
for
id
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
{
delete
(
ap
.
conns
,
id
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
id
)
}
continue
}
// 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。
if
conn
.
isLeased
()
||
conn
.
waiters
.
Load
()
>
0
||
p
.
isConnPinnedLocked
(
ap
,
conn
.
id
)
{
continue
}
idleConns
=
append
(
idleConns
,
conn
)
}
sort
.
SliceStable
(
idleConns
,
func
(
i
,
j
int
)
bool
{
return
idleConns
[
i
]
.
lastUsedAt
()
.
Before
(
idleConns
[
j
]
.
lastUsedAt
())
})
redundant
:=
len
(
ap
.
conns
)
-
maxIdle
if
redundant
>
len
(
idleConns
)
{
redundant
=
len
(
idleConns
)
}
for
i
:=
0
;
i
<
redundant
;
i
++
{
conn
:=
idleConns
[
i
]
delete
(
ap
.
conns
,
conn
.
id
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
conn
.
id
)
}
evicted
=
append
(
evicted
,
conn
)
}
if
redundant
>
0
{
p
.
metrics
.
scaleDownTotal
.
Add
(
int64
(
redundant
))
}
}
return
evicted
}
func
(
p
*
openAIWSConnPool
)
pickLeastBusyConnLocked
(
ap
*
openAIWSAccountPool
,
preferredConnID
string
)
*
openAIWSConn
{
if
ap
==
nil
||
len
(
ap
.
conns
)
==
0
{
return
nil
}
preferredConnID
=
stringsTrim
(
preferredConnID
)
if
preferredConnID
!=
""
{
if
conn
,
ok
:=
ap
.
conns
[
preferredConnID
];
ok
{
return
conn
}
}
var
best
*
openAIWSConn
var
bestWaiters
int32
var
bestLastUsed
time
.
Time
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
{
continue
}
waiters
:=
conn
.
waiters
.
Load
()
lastUsed
:=
conn
.
lastUsedAt
()
if
best
==
nil
||
waiters
<
bestWaiters
||
(
waiters
==
bestWaiters
&&
lastUsed
.
Before
(
bestLastUsed
))
{
best
=
conn
bestWaiters
=
waiters
bestLastUsed
=
lastUsed
}
}
return
best
}
func
accountPoolLoadLocked
(
ap
*
openAIWSAccountPool
)
(
inflight
int
,
waiters
int
)
{
if
ap
==
nil
{
return
0
,
0
}
for
_
,
conn
:=
range
ap
.
conns
{
if
conn
==
nil
{
continue
}
if
conn
.
isLeased
()
{
inflight
++
}
waiters
+=
int
(
conn
.
waiters
.
Load
())
}
return
inflight
,
waiters
}
// AccountPoolLoad 返回指定账号连接池的并发与排队快照。
func
(
p
*
openAIWSConnPool
)
AccountPoolLoad
(
accountID
int64
)
(
inflight
int
,
waiters
int
,
conns
int
)
{
if
p
==
nil
||
accountID
<=
0
{
return
0
,
0
,
0
}
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
0
,
0
,
0
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
inflight
,
waiters
=
accountPoolLoadLocked
(
ap
)
return
inflight
,
waiters
,
len
(
ap
.
conns
)
}
func
(
p
*
openAIWSConnPool
)
ensureTargetIdleAsync
(
accountID
int64
)
{
if
p
==
nil
||
accountID
<=
0
{
return
}
var
req
openAIWSAcquireRequest
need
:=
0
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
if
ap
.
lastAcquire
==
nil
{
return
}
if
ap
.
prewarmActive
{
return
}
now
:=
time
.
Now
()
if
!
ap
.
prewarmUntil
.
IsZero
()
&&
now
.
Before
(
ap
.
prewarmUntil
)
{
return
}
if
p
.
shouldSuppressPrewarmLocked
(
ap
,
now
)
{
return
}
effectiveMaxConns
:=
p
.
maxConnsHardCap
()
if
ap
.
lastAcquire
!=
nil
&&
ap
.
lastAcquire
.
Account
!=
nil
{
effectiveMaxConns
=
p
.
effectiveMaxConnsByAccount
(
ap
.
lastAcquire
.
Account
)
}
target
:=
p
.
targetConnCountLocked
(
ap
,
effectiveMaxConns
)
current
:=
len
(
ap
.
conns
)
+
ap
.
creating
if
current
>=
target
{
return
}
need
=
target
-
current
if
need
<=
0
{
return
}
req
=
cloneOpenAIWSAcquireRequest
(
*
ap
.
lastAcquire
)
ap
.
prewarmActive
=
true
if
cooldown
:=
p
.
prewarmCooldown
();
cooldown
>
0
{
ap
.
prewarmUntil
=
now
.
Add
(
cooldown
)
}
ap
.
creating
+=
need
p
.
metrics
.
scaleUpTotal
.
Add
(
int64
(
need
))
go
p
.
prewarmConns
(
accountID
,
req
,
need
)
}
func
(
p
*
openAIWSConnPool
)
targetConnCountLocked
(
ap
*
openAIWSAccountPool
,
maxConns
int
)
int
{
if
ap
==
nil
{
return
0
}
if
maxConns
<=
0
{
return
0
}
minIdle
:=
p
.
minIdlePerAccount
()
if
minIdle
<
0
{
minIdle
=
0
}
if
minIdle
>
maxConns
{
minIdle
=
maxConns
}
inflight
,
waiters
:=
accountPoolLoadLocked
(
ap
)
utilization
:=
p
.
targetUtilization
()
demand
:=
inflight
+
waiters
if
demand
<=
0
{
return
minIdle
}
target
:=
1
if
demand
>
1
{
target
=
int
(
math
.
Ceil
(
float64
(
demand
)
/
utilization
))
}
if
waiters
>
0
&&
target
<
len
(
ap
.
conns
)
+
1
{
target
=
len
(
ap
.
conns
)
+
1
}
if
target
<
minIdle
{
target
=
minIdle
}
if
target
>
maxConns
{
target
=
maxConns
}
return
target
}
func
(
p
*
openAIWSConnPool
)
prewarmConns
(
accountID
int64
,
req
openAIWSAcquireRequest
,
total
int
)
{
defer
func
()
{
if
ap
,
ok
:=
p
.
getAccountPool
(
accountID
);
ok
&&
ap
!=
nil
{
ap
.
mu
.
Lock
()
ap
.
prewarmActive
=
false
ap
.
mu
.
Unlock
()
}
}()
for
i
:=
0
;
i
<
total
;
i
++
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
p
.
dialTimeout
()
+
openAIWSConnPrewarmExtraDelay
)
conn
,
err
:=
p
.
dialConn
(
ctx
,
req
)
cancel
()
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
if
conn
!=
nil
{
conn
.
close
()
}
return
}
ap
.
mu
.
Lock
()
if
ap
.
creating
>
0
{
ap
.
creating
--
}
if
err
!=
nil
{
ap
.
prewarmFails
++
ap
.
prewarmFailAt
=
time
.
Now
()
ap
.
mu
.
Unlock
()
continue
}
if
len
(
ap
.
conns
)
>=
p
.
effectiveMaxConnsByAccount
(
req
.
Account
)
{
ap
.
mu
.
Unlock
()
conn
.
close
()
continue
}
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
prewarmFails
=
0
ap
.
prewarmFailAt
=
time
.
Time
{}
ap
.
mu
.
Unlock
()
}
}
func
(
p
*
openAIWSConnPool
)
evictConn
(
accountID
int64
,
connID
string
)
{
if
p
==
nil
||
accountID
<=
0
||
stringsTrim
(
connID
)
==
""
{
return
}
var
conn
*
openAIWSConn
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
ok
&&
ap
!=
nil
{
ap
.
mu
.
Lock
()
if
c
,
exists
:=
ap
.
conns
[
connID
];
exists
{
conn
=
c
delete
(
ap
.
conns
,
connID
)
if
len
(
ap
.
pinnedConns
)
>
0
{
delete
(
ap
.
pinnedConns
,
connID
)
}
}
ap
.
mu
.
Unlock
()
}
if
conn
!=
nil
{
conn
.
close
()
}
}
func
(
p
*
openAIWSConnPool
)
PinConn
(
accountID
int64
,
connID
string
)
bool
{
if
p
==
nil
||
accountID
<=
0
{
return
false
}
connID
=
stringsTrim
(
connID
)
if
connID
==
""
{
return
false
}
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
false
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
if
_
,
exists
:=
ap
.
conns
[
connID
];
!
exists
{
return
false
}
if
ap
.
pinnedConns
==
nil
{
ap
.
pinnedConns
=
make
(
map
[
string
]
int
)
}
ap
.
pinnedConns
[
connID
]
++
return
true
}
func
(
p
*
openAIWSConnPool
)
UnpinConn
(
accountID
int64
,
connID
string
)
{
if
p
==
nil
||
accountID
<=
0
{
return
}
connID
=
stringsTrim
(
connID
)
if
connID
==
""
{
return
}
ap
,
ok
:=
p
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
if
len
(
ap
.
pinnedConns
)
==
0
{
return
}
count
:=
ap
.
pinnedConns
[
connID
]
if
count
<=
1
{
delete
(
ap
.
pinnedConns
,
connID
)
return
}
ap
.
pinnedConns
[
connID
]
=
count
-
1
}
func
(
p
*
openAIWSConnPool
)
dialConn
(
ctx
context
.
Context
,
req
openAIWSAcquireRequest
)
(
*
openAIWSConn
,
error
)
{
if
p
==
nil
||
p
.
clientDialer
==
nil
{
return
nil
,
errors
.
New
(
"openai ws client dialer is nil"
)
}
conn
,
status
,
handshakeHeaders
,
err
:=
p
.
clientDialer
.
Dial
(
ctx
,
req
.
WSURL
,
req
.
Headers
,
req
.
ProxyURL
)
if
err
!=
nil
{
return
nil
,
&
openAIWSDialError
{
StatusCode
:
status
,
ResponseHeaders
:
cloneHeader
(
handshakeHeaders
),
Err
:
err
,
}
}
if
conn
==
nil
{
return
nil
,
&
openAIWSDialError
{
StatusCode
:
status
,
ResponseHeaders
:
cloneHeader
(
handshakeHeaders
),
Err
:
errors
.
New
(
"openai ws dialer returned nil connection"
),
}
}
id
:=
p
.
nextConnID
(
req
.
Account
.
ID
)
return
newOpenAIWSConn
(
id
,
req
.
Account
.
ID
,
conn
,
handshakeHeaders
),
nil
}
func
(
p
*
openAIWSConnPool
)
nextConnID
(
accountID
int64
)
string
{
seq
:=
p
.
seq
.
Add
(
1
)
buf
:=
make
([]
byte
,
0
,
32
)
buf
=
append
(
buf
,
"oa_ws_"
...
)
buf
=
strconv
.
AppendInt
(
buf
,
accountID
,
10
)
buf
=
append
(
buf
,
'_'
)
buf
=
strconv
.
AppendUint
(
buf
,
seq
,
10
)
return
string
(
buf
)
}
func
(
p
*
openAIWSConnPool
)
shouldHealthCheckConn
(
conn
*
openAIWSConn
)
bool
{
if
conn
==
nil
{
return
false
}
return
conn
.
idleDuration
(
time
.
Now
())
>=
openAIWSConnHealthCheckIdle
}
func
(
p
*
openAIWSConnPool
)
maxConnsHardCap
()
int
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
>
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
}
return
8
}
func
(
p
*
openAIWSConnPool
)
dynamicMaxConnsEnabled
()
bool
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
DynamicMaxConnsByAccountConcurrencyEnabled
}
return
false
}
func
(
p
*
openAIWSConnPool
)
modeRouterV2Enabled
()
bool
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
}
return
false
}
func
(
p
*
openAIWSConnPool
)
maxConnsFactorByAccount
(
account
*
Account
)
float64
{
if
p
==
nil
||
p
.
cfg
==
nil
||
account
==
nil
{
return
1.0
}
switch
account
.
Type
{
case
AccountTypeOAuth
:
if
p
.
cfg
.
Gateway
.
OpenAIWS
.
OAuthMaxConnsFactor
>
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
OAuthMaxConnsFactor
}
case
AccountTypeAPIKey
:
if
p
.
cfg
.
Gateway
.
OpenAIWS
.
APIKeyMaxConnsFactor
>
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
APIKeyMaxConnsFactor
}
}
return
1.0
}
func
(
p
*
openAIWSConnPool
)
effectiveMaxConnsByAccount
(
account
*
Account
)
int
{
hardCap
:=
p
.
maxConnsHardCap
()
if
hardCap
<=
0
{
return
0
}
if
p
.
modeRouterV2Enabled
()
{
if
account
==
nil
{
return
hardCap
}
if
account
.
Concurrency
<=
0
{
return
0
}
return
account
.
Concurrency
}
if
account
==
nil
||
!
p
.
dynamicMaxConnsEnabled
()
{
return
hardCap
}
if
account
.
Concurrency
<=
0
{
// 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。
return
hardCap
}
factor
:=
p
.
maxConnsFactorByAccount
(
account
)
if
factor
<=
0
{
factor
=
1.0
}
effective
:=
int
(
math
.
Ceil
(
float64
(
account
.
Concurrency
)
*
factor
))
if
effective
<
1
{
effective
=
1
}
if
effective
>
hardCap
{
effective
=
hardCap
}
return
effective
}
func
(
p
*
openAIWSConnPool
)
minIdlePerAccount
()
int
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
>=
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
}
return
0
}
func
(
p
*
openAIWSConnPool
)
maxIdlePerAccount
()
int
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
>=
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
}
return
4
}
func
(
p
*
openAIWSConnPool
)
maxConnAge
()
time
.
Duration
{
return
openAIWSConnMaxAge
}
func
(
p
*
openAIWSConnPool
)
queueLimitPerConn
()
int
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
>
0
{
return
p
.
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
}
return
256
}
func
(
p
*
openAIWSConnPool
)
targetUtilization
()
float64
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
{
ratio
:=
p
.
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
if
ratio
>
0
&&
ratio
<=
1
{
return
ratio
}
}
return
0.7
}
func
(
p
*
openAIWSConnPool
)
prewarmCooldown
()
time
.
Duration
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
PrewarmCooldownMS
>
0
{
return
time
.
Duration
(
p
.
cfg
.
Gateway
.
OpenAIWS
.
PrewarmCooldownMS
)
*
time
.
Millisecond
}
return
0
}
func
(
p
*
openAIWSConnPool
)
shouldSuppressPrewarmLocked
(
ap
*
openAIWSAccountPool
,
now
time
.
Time
)
bool
{
if
ap
==
nil
{
return
true
}
if
ap
.
prewarmFails
<=
0
{
return
false
}
if
ap
.
prewarmFailAt
.
IsZero
()
{
ap
.
prewarmFails
=
0
return
false
}
if
now
.
Sub
(
ap
.
prewarmFailAt
)
>
openAIWSPrewarmFailureWindow
{
ap
.
prewarmFails
=
0
ap
.
prewarmFailAt
=
time
.
Time
{}
return
false
}
return
ap
.
prewarmFails
>=
openAIWSPrewarmFailureSuppress
}
func
(
p
*
openAIWSConnPool
)
dialTimeout
()
time
.
Duration
{
if
p
!=
nil
&&
p
.
cfg
!=
nil
&&
p
.
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
>
0
{
return
time
.
Duration
(
p
.
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
)
*
time
.
Second
}
return
10
*
time
.
Second
}
func
cloneOpenAIWSAcquireRequest
(
req
openAIWSAcquireRequest
)
openAIWSAcquireRequest
{
copied
:=
req
copied
.
Headers
=
cloneHeader
(
req
.
Headers
)
copied
.
WSURL
=
stringsTrim
(
req
.
WSURL
)
copied
.
ProxyURL
=
stringsTrim
(
req
.
ProxyURL
)
copied
.
PreferredConnID
=
stringsTrim
(
req
.
PreferredConnID
)
return
copied
}
func
cloneOpenAIWSAcquireRequestPtr
(
req
*
openAIWSAcquireRequest
)
*
openAIWSAcquireRequest
{
if
req
==
nil
{
return
nil
}
copied
:=
cloneOpenAIWSAcquireRequest
(
*
req
)
return
&
copied
}
func
cloneHeader
(
src
http
.
Header
)
http
.
Header
{
if
src
==
nil
{
return
nil
}
dst
:=
make
(
http
.
Header
,
len
(
src
))
for
k
,
vals
:=
range
src
{
if
len
(
vals
)
==
0
{
dst
[
k
]
=
nil
continue
}
copied
:=
make
([]
string
,
len
(
vals
))
copy
(
copied
,
vals
)
dst
[
k
]
=
copied
}
return
dst
}
func
closeOpenAIWSConns
(
conns
[]
*
openAIWSConn
)
{
if
len
(
conns
)
==
0
{
return
}
for
_
,
conn
:=
range
conns
{
if
conn
==
nil
{
continue
}
conn
.
close
()
}
}
func
stringsTrim
(
value
string
)
string
{
return
strings
.
TrimSpace
(
value
)
}
backend/internal/service/openai_ws_pool_benchmark_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func
BenchmarkOpenAIWSPoolAcquire
(
b
*
testing
.
B
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
8
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
256
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
&
openAIWSCountingDialer
{})
account
:=
&
Account
{
ID
:
1001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
req
:=
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
}
ctx
:=
context
.
Background
()
lease
,
err
:=
pool
.
Acquire
(
ctx
,
req
)
if
err
!=
nil
{
b
.
Fatalf
(
"warm acquire failed: %v"
,
err
)
}
lease
.
Release
()
b
.
ReportAllocs
()
b
.
ResetTimer
()
b
.
RunParallel
(
func
(
pb
*
testing
.
PB
)
{
for
pb
.
Next
()
{
var
(
got
*
openAIWSConnLease
acquireErr
error
)
for
retry
:=
0
;
retry
<
3
;
retry
++
{
got
,
acquireErr
=
pool
.
Acquire
(
ctx
,
req
)
if
acquireErr
==
nil
{
break
}
if
!
errors
.
Is
(
acquireErr
,
errOpenAIWSConnClosed
)
{
break
}
}
if
acquireErr
!=
nil
{
b
.
Fatalf
(
"acquire failed: %v"
,
acquireErr
)
}
got
.
Release
()
}
})
}
backend/internal/service/openai_ws_pool_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"errors"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestOpenAIWSConnPool_CleanupStaleAndTrimIdle
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
10
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
stale
:=
newOpenAIWSConn
(
"stale"
,
accountID
,
nil
,
nil
)
stale
.
createdAtNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
stale
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
idleOld
:=
newOpenAIWSConn
(
"idle_old"
,
accountID
,
nil
,
nil
)
idleOld
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
10
*
time
.
Minute
)
.
UnixNano
())
idleNew
:=
newOpenAIWSConn
(
"idle_new"
,
accountID
,
nil
,
nil
)
idleNew
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
1
*
time
.
Minute
)
.
UnixNano
())
ap
.
conns
[
stale
.
id
]
=
stale
ap
.
conns
[
idleOld
.
id
]
=
idleOld
ap
.
conns
[
idleNew
.
id
]
=
idleNew
evicted
:=
pool
.
cleanupAccountLocked
(
ap
,
time
.
Now
(),
pool
.
maxConnsHardCap
())
closeOpenAIWSConns
(
evicted
)
require
.
Nil
(
t
,
ap
.
conns
[
"stale"
],
"stale connection should be rotated"
)
require
.
Nil
(
t
,
ap
.
conns
[
"idle_old"
],
"old idle should be trimmed by max_idle"
)
require
.
NotNil
(
t
,
ap
.
conns
[
"idle_new"
],
"newer idle should be kept"
)
}
func
TestOpenAIWSConnPool_NextConnIDFormat
(
t
*
testing
.
T
)
{
pool
:=
newOpenAIWSConnPool
(
&
config
.
Config
{})
id1
:=
pool
.
nextConnID
(
42
)
id2
:=
pool
.
nextConnID
(
42
)
require
.
True
(
t
,
strings
.
HasPrefix
(
id1
,
"oa_ws_42_"
))
require
.
True
(
t
,
strings
.
HasPrefix
(
id2
,
"oa_ws_42_"
))
require
.
NotEqual
(
t
,
id1
,
id2
)
require
.
Equal
(
t
,
"oa_ws_42_1"
,
id1
)
require
.
Equal
(
t
,
"oa_ws_42_2"
,
id2
)
}
func
TestOpenAIWSConnPool_AcquireCleanupInterval
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
3
*
time
.
Second
,
openAIWSAcquireCleanupInterval
)
require
.
Less
(
t
,
openAIWSAcquireCleanupInterval
,
openAIWSBackgroundSweepTicker
)
}
func
TestOpenAIWSConnLease_WriteJSONAndGuards
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"lease_write"
,
1
,
&
openAIWSFakeConn
{},
nil
)
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
require
.
NoError
(
t
,
lease
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.create"
},
0
))
var
nilLease
*
openAIWSConnLease
err
:=
nilLease
.
WriteJSONWithContextTimeout
(
context
.
Background
(),
map
[
string
]
any
{
"type"
:
"response.create"
},
time
.
Second
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
err
=
(
&
openAIWSConnLease
{})
.
WriteJSONWithContextTimeout
(
context
.
Background
(),
map
[
string
]
any
{
"type"
:
"response.create"
},
time
.
Second
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
}
func
TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground
(
t
*
testing
.
T
)
{
probe
:=
&
openAIWSContextProbeConn
{}
conn
:=
newOpenAIWSConn
(
"ctx_probe"
,
1
,
probe
,
nil
)
require
.
NoError
(
t
,
conn
.
writeJSONWithTimeout
(
context
.
Background
(),
map
[
string
]
any
{
"type"
:
"response.create"
},
0
))
require
.
NotNil
(
t
,
probe
.
lastWriteCtx
)
}
func
TestOpenAIWSConnPool_TargetConnCountAdaptive
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
6
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.5
pool
:=
newOpenAIWSConnPool
(
cfg
)
ap
:=
pool
.
getOrCreateAccountPool
(
88
)
conn1
:=
newOpenAIWSConn
(
"c1"
,
88
,
nil
,
nil
)
conn2
:=
newOpenAIWSConn
(
"c2"
,
88
,
nil
,
nil
)
require
.
True
(
t
,
conn1
.
tryAcquire
())
require
.
True
(
t
,
conn2
.
tryAcquire
())
conn1
.
waiters
.
Store
(
1
)
conn2
.
waiters
.
Store
(
1
)
ap
.
conns
[
conn1
.
id
]
=
conn1
ap
.
conns
[
conn2
.
id
]
=
conn2
target
:=
pool
.
targetConnCountLocked
(
ap
,
pool
.
maxConnsHardCap
())
require
.
Equal
(
t
,
6
,
target
,
"应按 inflight+waiters 与 target_utilization 自适应扩容到上限"
)
conn1
.
release
()
conn2
.
release
()
conn1
.
waiters
.
Store
(
0
)
conn2
.
waiters
.
Store
(
0
)
target
=
pool
.
targetConnCountLocked
(
ap
,
pool
.
maxConnsHardCap
())
require
.
Equal
(
t
,
1
,
target
,
"低负载时应缩回到最小空闲连接"
)
}
func
TestOpenAIWSConnPool_TargetConnCountMinIdleZero
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.8
pool
:=
newOpenAIWSConnPool
(
cfg
)
ap
:=
pool
.
getOrCreateAccountPool
(
66
)
target
:=
pool
.
targetConnCountLocked
(
ap
,
pool
.
maxConnsHardCap
())
require
.
Equal
(
t
,
0
,
target
,
"min_idle=0 且无负载时应允许缩容到 0"
)
}
func
TestOpenAIWSConnPool_EnsureTargetIdleAsync
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
&
openAIWSFakeDialer
{})
accountID
:=
int64
(
77
)
account
:=
&
Account
{
ID
:
accountID
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
lastAcquire
=
&
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
}
ap
.
mu
.
Unlock
()
pool
.
ensureTargetIdleAsync
(
accountID
)
require
.
Eventually
(
t
,
func
()
bool
{
ap
,
ok
:=
pool
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
false
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
return
len
(
ap
.
conns
)
>=
2
},
2
*
time
.
Second
,
20
*
time
.
Millisecond
)
metrics
:=
pool
.
SnapshotMetrics
()
require
.
GreaterOrEqual
(
t
,
metrics
.
ScaleUpTotal
,
int64
(
2
))
}
func
TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
1
cfg
.
Gateway
.
OpenAIWS
.
PrewarmCooldownMS
=
500
pool
:=
newOpenAIWSConnPool
(
cfg
)
dialer
:=
&
openAIWSCountingDialer
{}
pool
.
setClientDialerForTest
(
dialer
)
accountID
:=
int64
(
178
)
account
:=
&
Account
{
ID
:
accountID
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
lastAcquire
=
&
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
}
ap
.
mu
.
Unlock
()
pool
.
ensureTargetIdleAsync
(
accountID
)
require
.
Eventually
(
t
,
func
()
bool
{
ap
,
ok
:=
pool
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
false
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
return
len
(
ap
.
conns
)
>=
2
&&
!
ap
.
prewarmActive
},
2
*
time
.
Second
,
20
*
time
.
Millisecond
)
firstDialCount
:=
dialer
.
DialCount
()
require
.
GreaterOrEqual
(
t
,
firstDialCount
,
2
)
// 人工制造缺口触发新一轮预热需求。
ap
,
ok
:=
pool
.
getAccountPool
(
accountID
)
require
.
True
(
t
,
ok
)
require
.
NotNil
(
t
,
ap
)
ap
.
mu
.
Lock
()
for
id
:=
range
ap
.
conns
{
delete
(
ap
.
conns
,
id
)
break
}
ap
.
mu
.
Unlock
()
pool
.
ensureTargetIdleAsync
(
accountID
)
time
.
Sleep
(
120
*
time
.
Millisecond
)
require
.
Equal
(
t
,
firstDialCount
,
dialer
.
DialCount
(),
"cooldown 窗口内不应再次触发预热"
)
time
.
Sleep
(
450
*
time
.
Millisecond
)
pool
.
ensureTargetIdleAsync
(
accountID
)
require
.
Eventually
(
t
,
func
()
bool
{
return
dialer
.
DialCount
()
>
firstDialCount
},
2
*
time
.
Second
,
20
*
time
.
Millisecond
)
}
func
TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.8
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
1
cfg
.
Gateway
.
OpenAIWS
.
PrewarmCooldownMS
=
0
pool
:=
newOpenAIWSConnPool
(
cfg
)
dialer
:=
&
openAIWSAlwaysFailDialer
{}
pool
.
setClientDialerForTest
(
dialer
)
accountID
:=
int64
(
279
)
account
:=
&
Account
{
ID
:
accountID
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
lastAcquire
=
&
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
}
ap
.
mu
.
Unlock
()
pool
.
ensureTargetIdleAsync
(
accountID
)
require
.
Eventually
(
t
,
func
()
bool
{
ap
,
ok
:=
pool
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
false
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
return
!
ap
.
prewarmActive
},
2
*
time
.
Second
,
20
*
time
.
Millisecond
)
pool
.
ensureTargetIdleAsync
(
accountID
)
require
.
Eventually
(
t
,
func
()
bool
{
ap
,
ok
:=
pool
.
getAccountPool
(
accountID
)
if
!
ok
||
ap
==
nil
{
return
false
}
ap
.
mu
.
Lock
()
defer
ap
.
mu
.
Unlock
()
return
!
ap
.
prewarmActive
},
2
*
time
.
Second
,
20
*
time
.
Millisecond
)
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
())
// 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。
pool
.
ensureTargetIdleAsync
(
accountID
)
time
.
Sleep
(
120
*
time
.
Millisecond
)
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
())
}
func
TestOpenAIWSConnPool_AcquireQueueWaitMetrics
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
4
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
99
)
account
:=
&
Account
{
ID
:
accountID
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
conn
:=
newOpenAIWSConn
(
"busy"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
conn
.
tryAcquire
())
// 占用连接,触发后续排队
ap
:=
pool
.
ensureAccountPoolLocked
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
lastAcquire
=
&
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
}
ap
.
mu
.
Unlock
()
go
func
()
{
time
.
Sleep
(
60
*
time
.
Millisecond
)
conn
.
release
()
}()
lease
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
lease
)
require
.
True
(
t
,
lease
.
Reused
())
require
.
GreaterOrEqual
(
t
,
lease
.
QueueWaitDuration
(),
50
*
time
.
Millisecond
)
lease
.
Release
()
metrics
:=
pool
.
SnapshotMetrics
()
require
.
GreaterOrEqual
(
t
,
metrics
.
AcquireQueueWaitTotal
,
int64
(
1
))
require
.
Greater
(
t
,
metrics
.
AcquireQueueWaitMsTotal
,
int64
(
0
))
require
.
GreaterOrEqual
(
t
,
metrics
.
ConnPickTotal
,
int64
(
1
))
}
func
TestOpenAIWSConnPool_ForceNewConnSkipsReuse
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
2
pool
:=
newOpenAIWSConnPool
(
cfg
)
dialer
:=
&
openAIWSCountingDialer
{}
pool
.
setClientDialerForTest
(
dialer
)
account
:=
&
Account
{
ID
:
123
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
lease1
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
lease1
)
lease1
.
Release
()
lease2
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
ForceNewConn
:
true
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
lease2
)
lease2
.
Release
()
require
.
Equal
(
t
,
2
,
dialer
.
DialCount
(),
"ForceNewConn=true 时应跳过空闲连接复用并新建连接"
)
}
func
TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
2
pool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
ID
:
124
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
account
.
ID
)
otherConn
:=
newOpenAIWSConn
(
"other_conn"
,
account
.
ID
,
&
openAIWSFakeConn
{},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
otherConn
.
id
]
=
otherConn
ap
.
mu
.
Unlock
()
_
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
ForcePreferredConn
:
true
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSPreferredConnUnavailable
)
_
,
err
=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
PreferredConnID
:
"missing_conn"
,
ForcePreferredConn
:
true
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSPreferredConnUnavailable
)
}
func
TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
4
pool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
ID
:
125
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
account
.
ID
)
preferredConn
:=
newOpenAIWSConn
(
"preferred_conn"
,
account
.
ID
,
&
openAIWSFakeConn
{},
nil
)
otherConn
:=
newOpenAIWSConn
(
"other_conn_idle"
,
account
.
ID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
preferredConn
.
tryAcquire
(),
"先占用 preferred 连接,触发排队获取"
)
ap
.
mu
.
Lock
()
ap
.
conns
[
preferredConn
.
id
]
=
preferredConn
ap
.
conns
[
otherConn
.
id
]
=
otherConn
ap
.
lastCleanupAt
=
time
.
Now
()
ap
.
mu
.
Unlock
()
go
func
()
{
time
.
Sleep
(
60
*
time
.
Millisecond
)
preferredConn
.
release
()
}()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
)
defer
cancel
()
lease
,
err
:=
pool
.
Acquire
(
ctx
,
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
PreferredConnID
:
preferredConn
.
id
,
ForcePreferredConn
:
true
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
lease
)
require
.
Equal
(
t
,
preferredConn
.
id
,
lease
.
ConnID
(),
"严格模式应只等待并复用 preferred 连接,不可漂移"
)
require
.
GreaterOrEqual
(
t
,
lease
.
QueueWaitDuration
(),
40
*
time
.
Millisecond
)
lease
.
Release
()
require
.
True
(
t
,
otherConn
.
tryAcquire
(),
"other 连接不应被严格模式抢占"
)
otherConn
.
release
()
}
func
TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
ID
:
127
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
pool
.
getOrCreateAccountPool
(
account
.
ID
)
preferredConn
:=
newOpenAIWSConn
(
"preferred_conn_direct"
,
account
.
ID
,
&
openAIWSFakeConn
{},
nil
)
otherConn
:=
newOpenAIWSConn
(
"other_conn_direct"
,
account
.
ID
,
&
openAIWSFakeConn
{},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
preferredConn
.
id
]
=
preferredConn
ap
.
conns
[
otherConn
.
id
]
=
otherConn
ap
.
lastCleanupAt
=
time
.
Now
()
ap
.
mu
.
Unlock
()
lease
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
PreferredConnID
:
preferredConn
.
id
,
ForcePreferredConn
:
true
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
preferredConn
.
id
,
lease
.
ConnID
(),
"preferred 空闲时应直接命中"
)
lease
.
Release
()
require
.
True
(
t
,
preferredConn
.
tryAcquire
())
preferredConn
.
waiters
.
Store
(
1
)
_
,
err
=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
PreferredConnID
:
preferredConn
.
id
,
ForcePreferredConn
:
true
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnQueueFull
,
"严格模式下队列满应直接失败,不得漂移"
)
preferredConn
.
waiters
.
Store
(
0
)
preferredConn
.
release
()
}
func
TestOpenAIWSConnPool_CleanupSkipsPinnedConn
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
0
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
126
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
pinnedConn
:=
newOpenAIWSConn
(
"pinned_conn"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
idleConn
:=
newOpenAIWSConn
(
"idle_conn"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
pinnedConn
.
id
]
=
pinnedConn
ap
.
conns
[
idleConn
.
id
]
=
idleConn
ap
.
mu
.
Unlock
()
require
.
True
(
t
,
pool
.
PinConn
(
accountID
,
pinnedConn
.
id
))
evicted
:=
pool
.
cleanupAccountLocked
(
ap
,
time
.
Now
(),
pool
.
maxConnsHardCap
())
closeOpenAIWSConns
(
evicted
)
ap
.
mu
.
Lock
()
_
,
pinnedExists
:=
ap
.
conns
[
pinnedConn
.
id
]
_
,
idleExists
:=
ap
.
conns
[
idleConn
.
id
]
ap
.
mu
.
Unlock
()
require
.
True
(
t
,
pinnedExists
,
"被 active ingress 绑定的连接不应被 cleanup 回收"
)
require
.
False
(
t
,
idleExists
,
"非绑定的空闲连接应被回收"
)
pool
.
UnpinConn
(
accountID
,
pinnedConn
.
id
)
evicted
=
pool
.
cleanupAccountLocked
(
ap
,
time
.
Now
(),
pool
.
maxConnsHardCap
())
closeOpenAIWSConns
(
evicted
)
ap
.
mu
.
Lock
()
_
,
pinnedExists
=
ap
.
conns
[
pinnedConn
.
id
]
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
pinnedExists
,
"解绑后连接应可被正常回收"
)
}
func
TestOpenAIWSConnPool_PinUnpinConnBranches
(
t
*
testing
.
T
)
{
var
nilPool
*
openAIWSConnPool
require
.
False
(
t
,
nilPool
.
PinConn
(
1
,
"x"
))
nilPool
.
UnpinConn
(
1
,
"x"
)
cfg
:=
&
config
.
Config
{}
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
128
)
ap
:=
&
openAIWSAccountPool
{
conns
:
map
[
string
]
*
openAIWSConn
{},
}
pool
.
accounts
.
Store
(
accountID
,
ap
)
require
.
False
(
t
,
pool
.
PinConn
(
0
,
"x"
))
require
.
False
(
t
,
pool
.
PinConn
(
999
,
"x"
))
require
.
False
(
t
,
pool
.
PinConn
(
accountID
,
""
))
require
.
False
(
t
,
pool
.
PinConn
(
accountID
,
"missing"
))
conn
:=
newOpenAIWSConn
(
"pin_refcount"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
mu
.
Unlock
()
require
.
True
(
t
,
pool
.
PinConn
(
accountID
,
conn
.
id
))
require
.
True
(
t
,
pool
.
PinConn
(
accountID
,
conn
.
id
))
ap
.
mu
.
Lock
()
require
.
Equal
(
t
,
2
,
ap
.
pinnedConns
[
conn
.
id
])
ap
.
mu
.
Unlock
()
pool
.
UnpinConn
(
accountID
,
conn
.
id
)
ap
.
mu
.
Lock
()
require
.
Equal
(
t
,
1
,
ap
.
pinnedConns
[
conn
.
id
])
ap
.
mu
.
Unlock
()
pool
.
UnpinConn
(
accountID
,
conn
.
id
)
ap
.
mu
.
Lock
()
_
,
exists
:=
ap
.
pinnedConns
[
conn
.
id
]
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
exists
)
pool
.
UnpinConn
(
accountID
,
conn
.
id
)
pool
.
UnpinConn
(
accountID
,
""
)
pool
.
UnpinConn
(
0
,
conn
.
id
)
pool
.
UnpinConn
(
999
,
conn
.
id
)
}
func
TestOpenAIWSConnPool_EffectiveMaxConnsByAccount
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
8
cfg
.
Gateway
.
OpenAIWS
.
DynamicMaxConnsByAccountConcurrencyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthMaxConnsFactor
=
1.0
cfg
.
Gateway
.
OpenAIWS
.
APIKeyMaxConnsFactor
=
0.6
pool
:=
newOpenAIWSConnPool
(
cfg
)
oauthHigh
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
10
}
require
.
Equal
(
t
,
8
,
pool
.
effectiveMaxConnsByAccount
(
oauthHigh
),
"应受全局硬上限约束"
)
oauthLow
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
3
}
require
.
Equal
(
t
,
3
,
pool
.
effectiveMaxConnsByAccount
(
oauthLow
))
apiKeyHigh
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
10
}
require
.
Equal
(
t
,
6
,
pool
.
effectiveMaxConnsByAccount
(
apiKeyHigh
),
"API Key 应按系数缩放"
)
apiKeyLow
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
}
require
.
Equal
(
t
,
1
,
pool
.
effectiveMaxConnsByAccount
(
apiKeyLow
),
"最小值应保持为 1"
)
unlimited
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
0
}
require
.
Equal
(
t
,
8
,
pool
.
effectiveMaxConnsByAccount
(
unlimited
),
"无限并发应回退到全局硬上限"
)
require
.
Equal
(
t
,
8
,
pool
.
effectiveMaxConnsByAccount
(
nil
),
"缺少账号上下文应回退到全局硬上限"
)
}
func
TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
8
cfg
.
Gateway
.
OpenAIWS
.
DynamicMaxConnsByAccountConcurrencyEnabled
=
false
cfg
.
Gateway
.
OpenAIWS
.
OAuthMaxConnsFactor
=
1.0
cfg
.
Gateway
.
OpenAIWS
.
APIKeyMaxConnsFactor
=
1.0
pool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
2
}
require
.
Equal
(
t
,
8
,
pool
.
effectiveMaxConnsByAccount
(
account
),
"关闭动态模式后应保持旧行为"
)
}
func
TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
8
cfg
.
Gateway
.
OpenAIWS
.
DynamicMaxConnsByAccountConcurrencyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthMaxConnsFactor
=
0.3
cfg
.
Gateway
.
OpenAIWS
.
APIKeyMaxConnsFactor
=
0.6
pool
:=
newOpenAIWSConnPool
(
cfg
)
high
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
20
}
require
.
Equal
(
t
,
20
,
pool
.
effectiveMaxConnsByAccount
(
high
),
"v2 路径应直接使用账号并发数作为池上限"
)
nonPositive
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
0
}
require
.
Equal
(
t
,
0
,
pool
.
effectiveMaxConnsByAccount
(
nonPositive
),
"并发数<=0 时应不可调度"
)
}
func
TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
8
pool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
ID
:
901
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
0
}
_
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnQueueFull
)
}
func
TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"timeout"
,
1
,
&
openAIWSBlockingConn
{
readDelay
:
80
*
time
.
Millisecond
},
nil
)
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
_
,
err
:=
lease
.
ReadMessageWithContextTimeout
(
context
.
Background
(),
20
*
time
.
Millisecond
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
context
.
DeadlineExceeded
)
payload
,
err
:=
lease
.
ReadMessageWithContextTimeout
(
context
.
Background
(),
150
*
time
.
Millisecond
)
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
string
(
payload
),
"response.completed"
)
parentCtx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
_
,
err
=
lease
.
ReadMessageWithContextTimeout
(
parentCtx
,
150
*
time
.
Millisecond
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
context
.
Canceled
)
}
func
TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"write_timeout_ctx"
,
1
,
&
openAIWSWriteBlockingConn
{},
nil
)
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
parentCtx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
go
func
()
{
time
.
Sleep
(
20
*
time
.
Millisecond
)
cancel
()
}()
start
:=
time
.
Now
()
err
:=
lease
.
WriteJSONWithContextTimeout
(
parentCtx
,
map
[
string
]
any
{
"type"
:
"response.create"
},
2
*
time
.
Minute
)
elapsed
:=
time
.
Since
(
start
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
context
.
Canceled
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
)
}
func
TestOpenAIWSConnLease_PingWithTimeout
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"ping_ok"
,
1
,
&
openAIWSFakeConn
{},
nil
)
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
require
.
NoError
(
t
,
lease
.
PingWithTimeout
(
50
*
time
.
Millisecond
))
var
nilLease
*
openAIWSConnLease
err
:=
nilLease
.
PingWithTimeout
(
50
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
}
func
TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"full_duplex"
,
1
,
&
openAIWSBlockingConn
{
readDelay
:
120
*
time
.
Millisecond
},
nil
)
readDone
:=
make
(
chan
error
,
1
)
go
func
()
{
_
,
err
:=
conn
.
readMessageWithContextTimeout
(
context
.
Background
(),
200
*
time
.
Millisecond
)
readDone
<-
err
}()
// 让读取先占用 readMu。
time
.
Sleep
(
20
*
time
.
Millisecond
)
start
:=
time
.
Now
()
err
:=
conn
.
pingWithTimeout
(
50
*
time
.
Millisecond
)
elapsed
:=
time
.
Since
(
start
)
require
.
NoError
(
t
,
err
)
require
.
Less
(
t
,
elapsed
,
80
*
time
.
Millisecond
,
"写路径不应被读锁长期阻塞"
)
require
.
NoError
(
t
,
<-
readDone
)
}
func
TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
301
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
conn
:=
newOpenAIWSConn
(
"dead_idle"
,
accountID
,
&
openAIWSPingFailConn
{},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
mu
.
Unlock
()
pool
.
runBackgroundPingSweep
()
ap
.
mu
.
Lock
()
_
,
exists
:=
ap
.
conns
[
conn
.
id
]
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
exists
,
"后台 ping 失败的空闲连接应被回收"
)
}
func
TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
MaxIdlePerAccount
=
2
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
302
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
stale
:=
newOpenAIWSConn
(
"stale_bg"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
stale
.
createdAtNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
stale
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
ap
.
mu
.
Lock
()
ap
.
conns
[
stale
.
id
]
=
stale
ap
.
mu
.
Unlock
()
pool
.
runBackgroundCleanupSweep
(
time
.
Now
())
ap
.
mu
.
Lock
()
_
,
exists
:=
ap
.
conns
[
stale
.
id
]
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
exists
,
"后台清理应在无新 acquire 时也回收过期连接"
)
}
func
TestOpenAIWSConnPool_BackgroundWorkerGuardBranches
(
t
*
testing
.
T
)
{
var
nilPool
*
openAIWSConnPool
require
.
NotPanics
(
t
,
func
()
{
nilPool
.
startBackgroundWorkers
()
nilPool
.
runBackgroundPingWorker
()
nilPool
.
runBackgroundPingSweep
()
_
=
nilPool
.
snapshotIdleConnsForPing
()
nilPool
.
runBackgroundCleanupWorker
()
nilPool
.
runBackgroundCleanupSweep
(
time
.
Now
())
})
poolNoStop
:=
&
openAIWSConnPool
{}
require
.
NotPanics
(
t
,
func
()
{
poolNoStop
.
startBackgroundWorkers
()
})
poolStopPing
:=
&
openAIWSConnPool
{
workerStopCh
:
make
(
chan
struct
{})}
pingDone
:=
make
(
chan
struct
{})
go
func
()
{
poolStopPing
.
runBackgroundPingWorker
()
close
(
pingDone
)
}()
close
(
poolStopPing
.
workerStopCh
)
select
{
case
<-
pingDone
:
case
<-
time
.
After
(
500
*
time
.
Millisecond
)
:
t
.
Fatal
(
"runBackgroundPingWorker 未在 stop 信号后退出"
)
}
poolStopCleanup
:=
&
openAIWSConnPool
{
workerStopCh
:
make
(
chan
struct
{})}
cleanupDone
:=
make
(
chan
struct
{})
go
func
()
{
poolStopCleanup
.
runBackgroundCleanupWorker
()
close
(
cleanupDone
)
}()
close
(
poolStopCleanup
.
workerStopCh
)
select
{
case
<-
cleanupDone
:
case
<-
time
.
After
(
500
*
time
.
Millisecond
)
:
t
.
Fatal
(
"runBackgroundCleanupWorker 未在 stop 信号后退出"
)
}
}
func
TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries
(
t
*
testing
.
T
)
{
pool
:=
&
openAIWSConnPool
{}
pool
.
accounts
.
Store
(
"invalid-key"
,
&
openAIWSAccountPool
{})
pool
.
accounts
.
Store
(
int64
(
123
),
"invalid-value"
)
accountID
:=
int64
(
123
)
ap
:=
&
openAIWSAccountPool
{
conns
:
make
(
map
[
string
]
*
openAIWSConn
),
}
ap
.
conns
[
"nil_conn"
]
=
nil
leased
:=
newOpenAIWSConn
(
"leased"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
leased
.
tryAcquire
())
ap
.
conns
[
leased
.
id
]
=
leased
waiting
:=
newOpenAIWSConn
(
"waiting"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
waiting
.
waiters
.
Store
(
1
)
ap
.
conns
[
waiting
.
id
]
=
waiting
idle
:=
newOpenAIWSConn
(
"idle"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
ap
.
conns
[
idle
.
id
]
=
idle
pool
.
accounts
.
Store
(
accountID
,
ap
)
candidates
:=
pool
.
snapshotIdleConnsForPing
()
require
.
Len
(
t
,
candidates
,
1
)
require
.
Equal
(
t
,
idle
.
id
,
candidates
[
0
]
.
conn
.
id
)
}
func
TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
4
cfg
.
Gateway
.
OpenAIWS
.
DynamicMaxConnsByAccountConcurrencyEnabled
=
true
pool
:=
&
openAIWSConnPool
{
cfg
:
cfg
}
pool
.
accounts
.
Store
(
"bad-key"
,
"bad-value"
)
accountID
:=
int64
(
2026
)
ap
:=
&
openAIWSAccountPool
{
conns
:
make
(
map
[
string
]
*
openAIWSConn
),
}
ap
.
conns
[
"nil_conn"
]
=
nil
stale
:=
newOpenAIWSConn
(
"stale_bg_cleanup"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
stale
.
createdAtNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
stale
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
ap
.
conns
[
stale
.
id
]
=
stale
ap
.
lastAcquire
=
&
openAIWSAcquireRequest
{
Account
:
&
Account
{
ID
:
accountID
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
},
}
pool
.
accounts
.
Store
(
accountID
,
ap
)
now
:=
time
.
Now
()
require
.
NotPanics
(
t
,
func
()
{
pool
.
runBackgroundCleanupSweep
(
now
)
})
ap
.
mu
.
Lock
()
_
,
nilConnExists
:=
ap
.
conns
[
"nil_conn"
]
_
,
exists
:=
ap
.
conns
[
stale
.
id
]
lastCleanupAt
:=
ap
.
lastCleanupAt
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
nilConnExists
,
"后台清理应移除无效 nil 连接条目"
)
require
.
False
(
t
,
exists
,
"后台清理应清理过期连接"
)
require
.
Equal
(
t
,
now
,
lastCleanupAt
)
}
func
TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured
(
t
*
testing
.
T
)
{
var
nilPool
*
openAIWSConnPool
require
.
Equal
(
t
,
256
,
nilPool
.
queueLimitPerConn
())
pool
:=
&
openAIWSConnPool
{
cfg
:
&
config
.
Config
{}}
require
.
Equal
(
t
,
256
,
pool
.
queueLimitPerConn
())
pool
.
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
9
require
.
Equal
(
t
,
9
,
pool
.
queueLimitPerConn
())
}
func
TestOpenAIWSConnPool_Close
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
pool
:=
newOpenAIWSConnPool
(
cfg
)
// Close 应该可以安全调用
pool
.
Close
()
// workerStopCh 应已关闭
select
{
case
<-
pool
.
workerStopCh
:
// 预期:channel 已关闭
default
:
t
.
Fatal
(
"Close 后 workerStopCh 应已关闭"
)
}
// 多次调用 Close 不应 panic
pool
.
Close
()
// nil pool 调用 Close 不应 panic
var
nilPool
*
openAIWSConnPool
nilPool
.
Close
()
}
func
TestOpenAIWSDialError_ErrorAndUnwrap
(
t
*
testing
.
T
)
{
baseErr
:=
errors
.
New
(
"boom"
)
dialErr
:=
&
openAIWSDialError
{
StatusCode
:
502
,
Err
:
baseErr
}
require
.
Contains
(
t
,
dialErr
.
Error
(),
"status=502"
)
require
.
ErrorIs
(
t
,
dialErr
.
Unwrap
(),
baseErr
)
noStatus
:=
&
openAIWSDialError
{
Err
:
baseErr
}
require
.
Contains
(
t
,
noStatus
.
Error
(),
"boom"
)
var
nilDialErr
*
openAIWSDialError
require
.
Equal
(
t
,
""
,
nilDialErr
.
Error
())
require
.
NoError
(
t
,
nilDialErr
.
Unwrap
())
}
func
TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"helper_conn"
,
1
,
&
openAIWSFakeConn
{},
http
.
Header
{
"X-Test"
:
[]
string
{
" value "
},
})
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
require
.
NoError
(
t
,
lease
.
WriteJSONContext
(
context
.
Background
(),
map
[
string
]
any
{
"type"
:
"response.create"
}))
payload
,
err
:=
lease
.
ReadMessage
(
100
*
time
.
Millisecond
)
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
string
(
payload
),
"response.completed"
)
payload
,
err
=
lease
.
ReadMessageContext
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
string
(
payload
),
"response.completed"
)
payload
,
err
=
conn
.
readMessageWithTimeout
(
100
*
time
.
Millisecond
)
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
string
(
payload
),
"response.completed"
)
require
.
Equal
(
t
,
"value"
,
conn
.
handshakeHeader
(
" X-Test "
))
require
.
NotZero
(
t
,
conn
.
createdAt
())
require
.
NotZero
(
t
,
conn
.
lastUsedAt
())
require
.
GreaterOrEqual
(
t
,
conn
.
age
(
time
.
Now
()),
time
.
Duration
(
0
))
require
.
GreaterOrEqual
(
t
,
conn
.
idleDuration
(
time
.
Now
()),
time
.
Duration
(
0
))
require
.
False
(
t
,
conn
.
isLeased
())
// 覆盖空上下文路径
_
,
err
=
conn
.
readMessage
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
// 覆盖 nil 保护分支
var
nilConn
*
openAIWSConn
require
.
ErrorIs
(
t
,
nilConn
.
writeJSONWithTimeout
(
context
.
Background
(),
map
[
string
]
any
{},
time
.
Second
),
errOpenAIWSConnClosed
)
_
,
err
=
nilConn
.
readMessageWithTimeout
(
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
_
,
err
=
nilConn
.
readMessageWithContextTimeout
(
context
.
Background
(),
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
}
func
TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad
(
t
*
testing
.
T
)
{
pool
:=
&
openAIWSConnPool
{}
accountID
:=
int64
(
404
)
ap
:=
&
openAIWSAccountPool
{
conns
:
map
[
string
]
*
openAIWSConn
{}}
idleOld
:=
newOpenAIWSConn
(
"idle_old"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
idleOld
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
10
*
time
.
Minute
)
.
UnixNano
())
idleNew
:=
newOpenAIWSConn
(
"idle_new"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
idleNew
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
1
*
time
.
Minute
)
.
UnixNano
())
leased
:=
newOpenAIWSConn
(
"leased"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
leased
.
tryAcquire
())
leased
.
waiters
.
Store
(
2
)
ap
.
conns
[
idleOld
.
id
]
=
idleOld
ap
.
conns
[
idleNew
.
id
]
=
idleNew
ap
.
conns
[
leased
.
id
]
=
leased
oldest
:=
pool
.
pickOldestIdleConnLocked
(
ap
)
require
.
NotNil
(
t
,
oldest
)
require
.
Equal
(
t
,
idleOld
.
id
,
oldest
.
id
)
inflight
,
waiters
:=
accountPoolLoadLocked
(
ap
)
require
.
Equal
(
t
,
1
,
inflight
)
require
.
Equal
(
t
,
2
,
waiters
)
pool
.
accounts
.
Store
(
accountID
,
ap
)
loadInflight
,
loadWaiters
,
conns
:=
pool
.
AccountPoolLoad
(
accountID
)
require
.
Equal
(
t
,
1
,
loadInflight
)
require
.
Equal
(
t
,
2
,
loadWaiters
)
require
.
Equal
(
t
,
3
,
conns
)
zeroInflight
,
zeroWaiters
,
zeroConns
:=
pool
.
AccountPoolLoad
(
0
)
require
.
Equal
(
t
,
0
,
zeroInflight
)
require
.
Equal
(
t
,
0
,
zeroWaiters
)
require
.
Equal
(
t
,
0
,
zeroConns
)
}
func
TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel
(
t
*
testing
.
T
)
{
pool
:=
&
openAIWSConnPool
{}
release
:=
make
(
chan
struct
{})
pool
.
workerWg
.
Add
(
1
)
go
func
()
{
defer
pool
.
workerWg
.
Done
()
<-
release
}()
closed
:=
make
(
chan
struct
{})
go
func
()
{
pool
.
Close
()
close
(
closed
)
}()
select
{
case
<-
closed
:
t
.
Fatal
(
"Close 不应在 WaitGroup 未完成时提前返回"
)
case
<-
time
.
After
(
30
*
time
.
Millisecond
)
:
}
close
(
release
)
select
{
case
<-
closed
:
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"Close 未等待 workerWg 完成"
)
}
}
func
TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections
(
t
*
testing
.
T
)
{
pool
:=
&
openAIWSConnPool
{
workerStopCh
:
make
(
chan
struct
{}),
}
accountID
:=
int64
(
606
)
ap
:=
&
openAIWSAccountPool
{
conns
:
map
[
string
]
*
openAIWSConn
{},
}
idle
:=
newOpenAIWSConn
(
"idle_conn"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
leased
:=
newOpenAIWSConn
(
"leased_conn"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
leased
.
tryAcquire
())
ap
.
conns
[
idle
.
id
]
=
idle
ap
.
conns
[
leased
.
id
]
=
leased
pool
.
accounts
.
Store
(
accountID
,
ap
)
pool
.
accounts
.
Store
(
"invalid-key"
,
"invalid-value"
)
pool
.
Close
()
select
{
case
<-
idle
.
closedCh
:
// idle should be closed
default
:
t
.
Fatal
(
"空闲连接应在 Close 时被关闭"
)
}
select
{
case
<-
leased
.
closedCh
:
t
.
Fatal
(
"已租赁连接不应在 Close 时被关闭"
)
default
:
}
leased
.
release
()
pool
.
Close
()
}
func
TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
pool
:=
newOpenAIWSConnPool
(
cfg
)
accountID
:=
int64
(
505
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
var
current
atomic
.
Int32
var
maxConcurrent
atomic
.
Int32
release
:=
make
(
chan
struct
{})
for
i
:=
0
;
i
<
25
;
i
++
{
conn
:=
newOpenAIWSConn
(
pool
.
nextConnID
(
accountID
),
accountID
,
&
openAIWSPingBlockingConn
{
current
:
&
current
,
maxConcurrent
:
&
maxConcurrent
,
release
:
release
,
},
nil
)
ap
.
mu
.
Lock
()
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
mu
.
Unlock
()
}
done
:=
make
(
chan
struct
{})
go
func
()
{
pool
.
runBackgroundPingSweep
()
close
(
done
)
}()
require
.
Eventually
(
t
,
func
()
bool
{
return
maxConcurrent
.
Load
()
>=
10
},
time
.
Second
,
10
*
time
.
Millisecond
)
close
(
release
)
select
{
case
<-
done
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"runBackgroundPingSweep 未在释放后完成"
)
}
require
.
LessOrEqual
(
t
,
maxConcurrent
.
Load
(),
int32
(
10
))
}
func
TestOpenAIWSConnLease_BasicGetterBranches
(
t
*
testing
.
T
)
{
var
nilLease
*
openAIWSConnLease
require
.
Equal
(
t
,
""
,
nilLease
.
ConnID
())
require
.
Equal
(
t
,
time
.
Duration
(
0
),
nilLease
.
QueueWaitDuration
())
require
.
Equal
(
t
,
time
.
Duration
(
0
),
nilLease
.
ConnPickDuration
())
require
.
False
(
t
,
nilLease
.
Reused
())
require
.
Equal
(
t
,
""
,
nilLease
.
HandshakeHeader
(
"x-test"
))
require
.
False
(
t
,
nilLease
.
IsPrewarmed
())
nilLease
.
MarkPrewarmed
()
nilLease
.
Release
()
conn
:=
newOpenAIWSConn
(
"getter_conn"
,
1
,
&
openAIWSFakeConn
{},
http
.
Header
{
"X-Test"
:
[]
string
{
"ok"
}})
lease
:=
&
openAIWSConnLease
{
conn
:
conn
,
queueWait
:
3
*
time
.
Millisecond
,
connPick
:
4
*
time
.
Millisecond
,
reused
:
true
,
}
require
.
Equal
(
t
,
"getter_conn"
,
lease
.
ConnID
())
require
.
Equal
(
t
,
3
*
time
.
Millisecond
,
lease
.
QueueWaitDuration
())
require
.
Equal
(
t
,
4
*
time
.
Millisecond
,
lease
.
ConnPickDuration
())
require
.
True
(
t
,
lease
.
Reused
())
require
.
Equal
(
t
,
"ok"
,
lease
.
HandshakeHeader
(
"x-test"
))
require
.
False
(
t
,
lease
.
IsPrewarmed
())
lease
.
MarkPrewarmed
()
require
.
True
(
t
,
lease
.
IsPrewarmed
())
lease
.
Release
()
}
func
TestOpenAIWSConnPool_UtilityBranches
(
t
*
testing
.
T
)
{
var
nilPool
*
openAIWSConnPool
require
.
Equal
(
t
,
OpenAIWSPoolMetricsSnapshot
{},
nilPool
.
SnapshotMetrics
())
require
.
Equal
(
t
,
OpenAIWSTransportMetricsSnapshot
{},
nilPool
.
SnapshotTransportMetrics
())
pool
:=
&
openAIWSConnPool
{
cfg
:
&
config
.
Config
{}}
pool
.
metrics
.
acquireTotal
.
Store
(
7
)
pool
.
metrics
.
acquireReuseTotal
.
Store
(
3
)
metrics
:=
pool
.
SnapshotMetrics
()
require
.
Equal
(
t
,
int64
(
7
),
metrics
.
AcquireTotal
)
require
.
Equal
(
t
,
int64
(
3
),
metrics
.
AcquireReuseTotal
)
// 非 transport metrics dialer 路径
pool
.
clientDialer
=
&
openAIWSFakeDialer
{}
require
.
Equal
(
t
,
OpenAIWSTransportMetricsSnapshot
{},
pool
.
SnapshotTransportMetrics
())
pool
.
setClientDialerForTest
(
nil
)
require
.
NotNil
(
t
,
pool
.
clientDialer
)
require
.
Equal
(
t
,
8
,
nilPool
.
maxConnsHardCap
())
require
.
False
(
t
,
nilPool
.
dynamicMaxConnsEnabled
())
require
.
Equal
(
t
,
1.0
,
nilPool
.
maxConnsFactorByAccount
(
nil
))
require
.
Equal
(
t
,
0
,
nilPool
.
minIdlePerAccount
())
require
.
Equal
(
t
,
4
,
nilPool
.
maxIdlePerAccount
())
require
.
Equal
(
t
,
256
,
nilPool
.
queueLimitPerConn
())
require
.
Equal
(
t
,
0.7
,
nilPool
.
targetUtilization
())
require
.
Equal
(
t
,
time
.
Duration
(
0
),
nilPool
.
prewarmCooldown
())
require
.
Equal
(
t
,
10
*
time
.
Second
,
nilPool
.
dialTimeout
())
// shouldSuppressPrewarmLocked 覆盖 3 条分支
now
:=
time
.
Now
()
apNilFail
:=
&
openAIWSAccountPool
{
prewarmFails
:
1
}
require
.
False
(
t
,
pool
.
shouldSuppressPrewarmLocked
(
apNilFail
,
now
))
apZeroTime
:=
&
openAIWSAccountPool
{
prewarmFails
:
2
}
require
.
False
(
t
,
pool
.
shouldSuppressPrewarmLocked
(
apZeroTime
,
now
))
require
.
Equal
(
t
,
0
,
apZeroTime
.
prewarmFails
)
apOldFail
:=
&
openAIWSAccountPool
{
prewarmFails
:
2
,
prewarmFailAt
:
now
.
Add
(
-
openAIWSPrewarmFailureWindow
-
time
.
Second
)}
require
.
False
(
t
,
pool
.
shouldSuppressPrewarmLocked
(
apOldFail
,
now
))
apRecentFail
:=
&
openAIWSAccountPool
{
prewarmFails
:
openAIWSPrewarmFailureSuppress
,
prewarmFailAt
:
now
}
require
.
True
(
t
,
pool
.
shouldSuppressPrewarmLocked
(
apRecentFail
,
now
))
// recordConnPickDuration 的保护分支
nilPool
.
recordConnPickDuration
(
10
*
time
.
Millisecond
)
pool
.
recordConnPickDuration
(
-
10
*
time
.
Millisecond
)
require
.
Equal
(
t
,
int64
(
1
),
pool
.
metrics
.
connPickTotal
.
Load
())
// account pool 读写分支
require
.
Nil
(
t
,
nilPool
.
getOrCreateAccountPool
(
1
))
require
.
Nil
(
t
,
pool
.
getOrCreateAccountPool
(
0
))
pool
.
accounts
.
Store
(
int64
(
7
),
"invalid"
)
ap
:=
pool
.
getOrCreateAccountPool
(
7
)
require
.
NotNil
(
t
,
ap
)
_
,
ok
:=
pool
.
getAccountPool
(
0
)
require
.
False
(
t
,
ok
)
_
,
ok
=
pool
.
getAccountPool
(
12345
)
require
.
False
(
t
,
ok
)
pool
.
accounts
.
Store
(
int64
(
8
),
"bad-type"
)
_
,
ok
=
pool
.
getAccountPool
(
8
)
require
.
False
(
t
,
ok
)
// health check 条件
require
.
False
(
t
,
pool
.
shouldHealthCheckConn
(
nil
))
conn
:=
newOpenAIWSConn
(
"health"
,
1
,
&
openAIWSFakeConn
{},
nil
)
conn
.
lastUsedNano
.
Store
(
time
.
Now
()
.
Add
(
-
openAIWSConnHealthCheckIdle
-
time
.
Second
)
.
UnixNano
())
require
.
True
(
t
,
pool
.
shouldHealthCheckConn
(
conn
))
}
func
TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches
(
t
*
testing
.
T
)
{
var
nilConn
*
openAIWSConn
nilConn
.
touch
()
require
.
Equal
(
t
,
time
.
Time
{},
nilConn
.
createdAt
())
require
.
Equal
(
t
,
time
.
Time
{},
nilConn
.
lastUsedAt
())
require
.
Equal
(
t
,
time
.
Duration
(
0
),
nilConn
.
idleDuration
(
time
.
Now
()))
require
.
Equal
(
t
,
time
.
Duration
(
0
),
nilConn
.
age
(
time
.
Now
()))
require
.
False
(
t
,
nilConn
.
isLeased
())
require
.
False
(
t
,
nilConn
.
isPrewarmed
())
nilConn
.
markPrewarmed
()
conn
:=
newOpenAIWSConn
(
"lease_state"
,
1
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
conn
.
tryAcquire
())
require
.
True
(
t
,
conn
.
isLeased
())
conn
.
release
()
require
.
False
(
t
,
conn
.
isLeased
())
conn
.
close
()
require
.
False
(
t
,
conn
.
tryAcquire
())
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
err
:=
conn
.
acquire
(
ctx
)
require
.
Error
(
t
,
err
)
}
func
TestOpenAIWSConnLease_ReadWriteNilConnBranches
(
t
*
testing
.
T
)
{
lease
:=
&
openAIWSConnLease
{}
require
.
ErrorIs
(
t
,
lease
.
WriteJSON
(
map
[
string
]
any
{
"k"
:
"v"
},
time
.
Second
),
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
lease
.
WriteJSONContext
(
context
.
Background
(),
map
[
string
]
any
{
"k"
:
"v"
}),
errOpenAIWSConnClosed
)
_
,
err
:=
lease
.
ReadMessage
(
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
_
,
err
=
lease
.
ReadMessageContext
(
context
.
Background
())
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
_
,
err
=
lease
.
ReadMessageWithContextTimeout
(
context
.
Background
(),
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
}
func
TestOpenAIWSConnLease_ReleasedLeaseGuards
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"released_guard"
,
1
,
&
openAIWSFakeConn
{},
nil
)
lease
:=
&
openAIWSConnLease
{
conn
:
conn
}
require
.
NoError
(
t
,
lease
.
PingWithTimeout
(
50
*
time
.
Millisecond
))
lease
.
Release
()
lease
.
Release
()
// idempotent
require
.
ErrorIs
(
t
,
lease
.
WriteJSON
(
map
[
string
]
any
{
"k"
:
"v"
},
time
.
Second
),
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
lease
.
WriteJSONContext
(
context
.
Background
(),
map
[
string
]
any
{
"k"
:
"v"
}),
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
lease
.
WriteJSONWithContextTimeout
(
context
.
Background
(),
map
[
string
]
any
{
"k"
:
"v"
},
time
.
Second
),
errOpenAIWSConnClosed
)
_
,
err
:=
lease
.
ReadMessage
(
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
_
,
err
=
lease
.
ReadMessageContext
(
context
.
Background
())
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
_
,
err
=
lease
.
ReadMessageWithContextTimeout
(
context
.
Background
(),
10
*
time
.
Millisecond
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
lease
.
PingWithTimeout
(
50
*
time
.
Millisecond
),
errOpenAIWSConnClosed
)
}
func
TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction
(
t
*
testing
.
T
)
{
conn
:=
newOpenAIWSConn
(
"released_markbroken"
,
7
,
&
openAIWSFakeConn
{},
nil
)
ap
:=
&
openAIWSAccountPool
{
conns
:
map
[
string
]
*
openAIWSConn
{
conn
.
id
:
conn
,
},
}
pool
:=
&
openAIWSConnPool
{}
pool
.
accounts
.
Store
(
int64
(
7
),
ap
)
lease
:=
&
openAIWSConnLease
{
pool
:
pool
,
accountID
:
7
,
conn
:
conn
,
}
lease
.
Release
()
lease
.
MarkBroken
()
ap
.
mu
.
Lock
()
_
,
exists
:=
ap
.
conns
[
conn
.
id
]
ap
.
mu
.
Unlock
()
require
.
True
(
t
,
exists
,
"released lease should not evict active pool connection"
)
}
func
TestOpenAIWSConn_AdditionalGuardBranches
(
t
*
testing
.
T
)
{
var
nilConn
*
openAIWSConn
require
.
False
(
t
,
nilConn
.
tryAcquire
())
require
.
ErrorIs
(
t
,
nilConn
.
acquire
(
context
.
Background
()),
errOpenAIWSConnClosed
)
nilConn
.
release
()
nilConn
.
close
()
require
.
Equal
(
t
,
""
,
nilConn
.
handshakeHeader
(
"x-test"
))
connBusy
:=
newOpenAIWSConn
(
"busy_ctx"
,
1
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
connBusy
.
tryAcquire
())
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
require
.
ErrorIs
(
t
,
connBusy
.
acquire
(
ctx
),
context
.
Canceled
)
connBusy
.
release
()
connClosed
:=
newOpenAIWSConn
(
"closed_guard"
,
1
,
&
openAIWSFakeConn
{},
nil
)
connClosed
.
close
()
require
.
ErrorIs
(
t
,
connClosed
.
writeJSONWithTimeout
(
context
.
Background
(),
map
[
string
]
any
{
"k"
:
"v"
},
time
.
Second
),
errOpenAIWSConnClosed
,
)
_
,
err
:=
connClosed
.
readMessageWithContextTimeout
(
context
.
Background
(),
time
.
Second
)
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
connClosed
.
pingWithTimeout
(
time
.
Second
),
errOpenAIWSConnClosed
)
connNoWS
:=
newOpenAIWSConn
(
"no_ws"
,
1
,
nil
,
nil
)
require
.
ErrorIs
(
t
,
connNoWS
.
writeJSON
(
map
[
string
]
any
{
"k"
:
"v"
},
context
.
Background
()),
errOpenAIWSConnClosed
)
_
,
err
=
connNoWS
.
readMessage
(
context
.
Background
())
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
require
.
ErrorIs
(
t
,
connNoWS
.
pingWithTimeout
(
time
.
Second
),
errOpenAIWSConnClosed
)
require
.
Equal
(
t
,
""
,
connNoWS
.
handshakeHeader
(
"x-test"
))
connOK
:=
newOpenAIWSConn
(
"ok"
,
1
,
&
openAIWSFakeConn
{},
nil
)
require
.
NoError
(
t
,
connOK
.
writeJSON
(
map
[
string
]
any
{
"k"
:
"v"
},
nil
))
_
,
err
=
connOK
.
readMessageWithContextTimeout
(
context
.
Background
(),
0
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
connOK
.
pingWithTimeout
(
0
))
connZero
:=
newOpenAIWSConn
(
"zero_ts"
,
1
,
&
openAIWSFakeConn
{},
nil
)
connZero
.
createdAtNano
.
Store
(
0
)
connZero
.
lastUsedNano
.
Store
(
0
)
require
.
True
(
t
,
connZero
.
createdAt
()
.
IsZero
())
require
.
True
(
t
,
connZero
.
lastUsedAt
()
.
IsZero
())
require
.
Equal
(
t
,
time
.
Duration
(
0
),
connZero
.
idleDuration
(
time
.
Now
()))
require
.
Equal
(
t
,
time
.
Duration
(
0
),
connZero
.
age
(
time
.
Now
()))
require
.
Nil
(
t
,
cloneOpenAIWSAcquireRequestPtr
(
nil
))
copied
:=
cloneHeader
(
http
.
Header
{
"X-Empty"
:
[]
string
{},
"X-Test"
:
[]
string
{
"v1"
},
})
require
.
Contains
(
t
,
copied
,
"X-Empty"
)
require
.
Nil
(
t
,
copied
[
"X-Empty"
])
require
.
Equal
(
t
,
"v1"
,
copied
.
Get
(
"X-Test"
))
closeOpenAIWSConns
([]
*
openAIWSConn
{
nil
,
connOK
})
}
func
TestOpenAIWSConnLease_MarkBrokenEvictsConn
(
t
*
testing
.
T
)
{
pool
:=
newOpenAIWSConnPool
(
&
config
.
Config
{})
accountID
:=
int64
(
5001
)
conn
:=
newOpenAIWSConn
(
"broken_me"
,
accountID
,
&
openAIWSFakeConn
{},
nil
)
ap
:=
pool
.
getOrCreateAccountPool
(
accountID
)
ap
.
mu
.
Lock
()
ap
.
conns
[
conn
.
id
]
=
conn
ap
.
mu
.
Unlock
()
lease
:=
&
openAIWSConnLease
{
pool
:
pool
,
accountID
:
accountID
,
conn
:
conn
,
}
lease
.
MarkBroken
()
ap
.
mu
.
Lock
()
_
,
exists
:=
ap
.
conns
[
conn
.
id
]
ap
.
mu
.
Unlock
()
require
.
False
(
t
,
exists
)
require
.
False
(
t
,
conn
.
tryAcquire
(),
"被标记为 broken 的连接应被关闭"
)
}
func
TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
require
.
Equal
(
t
,
0
,
pool
.
targetConnCountLocked
(
nil
,
1
))
ap
:=
&
openAIWSAccountPool
{
conns
:
map
[
string
]
*
openAIWSConn
{}}
require
.
Equal
(
t
,
0
,
pool
.
targetConnCountLocked
(
ap
,
0
))
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
3
require
.
Equal
(
t
,
1
,
pool
.
targetConnCountLocked
(
ap
,
1
),
"minIdle 应被 maxConns 截断"
)
// 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支
cfg
.
Gateway
.
OpenAIWS
.
MinIdlePerAccount
=
0
cfg
.
Gateway
.
OpenAIWS
.
PoolTargetUtilization
=
0.9
busy
:=
newOpenAIWSConn
(
"busy_target"
,
2
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
busy
.
tryAcquire
())
busy
.
waiters
.
Store
(
1
)
ap
.
conns
[
busy
.
id
]
=
busy
target
:=
pool
.
targetConnCountLocked
(
ap
,
4
)
require
.
GreaterOrEqual
(
t
,
target
,
len
(
ap
.
conns
)
+
1
)
// prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回
req
:=
openAIWSAcquireRequest
{
Account
:
&
Account
{
ID
:
999
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
},
WSURL
:
"wss://example.com/v1/responses"
,
}
pool
.
prewarmConns
(
999
,
req
,
1
)
// prewarm: 拨号失败分支(prewarmFails 累加)
accountID
:=
int64
(
1000
)
failPool
:=
newOpenAIWSConnPool
(
cfg
)
failPool
.
setClientDialerForTest
(
&
openAIWSAlwaysFailDialer
{})
apFail
:=
failPool
.
getOrCreateAccountPool
(
accountID
)
apFail
.
mu
.
Lock
()
apFail
.
creating
=
1
apFail
.
mu
.
Unlock
()
req
.
Account
.
ID
=
accountID
failPool
.
prewarmConns
(
accountID
,
req
,
1
)
apFail
.
mu
.
Lock
()
require
.
GreaterOrEqual
(
t
,
apFail
.
prewarmFails
,
1
)
apFail
.
mu
.
Unlock
()
}
func
TestOpenAIWSConnPool_Acquire_ErrorBranches
(
t
*
testing
.
T
)
{
var
nilPool
*
openAIWSConnPool
_
,
err
:=
nilPool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{})
require
.
Error
(
t
,
err
)
pool
:=
newOpenAIWSConnPool
(
&
config
.
Config
{})
_
,
err
=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
&
Account
{
ID
:
1
},
WSURL
:
" "
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"ws url is empty"
)
// target=nil 分支:池满且仅有 nil 连接
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
1
cfg
.
Gateway
.
OpenAIWS
.
QueueLimitPerConn
=
1
fullPool
:=
newOpenAIWSConnPool
(
cfg
)
account
:=
&
Account
{
ID
:
2001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap
:=
fullPool
.
getOrCreateAccountPool
(
account
.
ID
)
ap
.
mu
.
Lock
()
ap
.
conns
[
"nil"
]
=
nil
ap
.
lastCleanupAt
=
time
.
Now
()
ap
.
mu
.
Unlock
()
_
,
err
=
fullPool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnClosed
)
// queue full 分支:waiters 达上限
account2
:=
&
Account
{
ID
:
2002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
ap2
:=
fullPool
.
getOrCreateAccountPool
(
account2
.
ID
)
conn
:=
newOpenAIWSConn
(
"queue_full"
,
account2
.
ID
,
&
openAIWSFakeConn
{},
nil
)
require
.
True
(
t
,
conn
.
tryAcquire
())
conn
.
waiters
.
Store
(
1
)
ap2
.
mu
.
Lock
()
ap2
.
conns
[
conn
.
id
]
=
conn
ap2
.
lastCleanupAt
=
time
.
Now
()
ap2
.
mu
.
Unlock
()
_
,
err
=
fullPool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account2
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
ErrorIs
(
t
,
err
,
errOpenAIWSConnQueueFull
)
}
type
openAIWSFakeDialer
struct
{}
func
(
d
*
openAIWSFakeDialer
)
Dial
(
ctx
context
.
Context
,
wsURL
string
,
headers
http
.
Header
,
proxyURL
string
,
)
(
openAIWSClientConn
,
int
,
http
.
Header
,
error
)
{
_
=
ctx
_
=
wsURL
_
=
headers
_
=
proxyURL
return
&
openAIWSFakeConn
{},
0
,
nil
,
nil
}
type
openAIWSCountingDialer
struct
{
mu
sync
.
Mutex
dialCount
int
}
type
openAIWSAlwaysFailDialer
struct
{
mu
sync
.
Mutex
dialCount
int
}
type
openAIWSPingBlockingConn
struct
{
current
*
atomic
.
Int32
maxConcurrent
*
atomic
.
Int32
release
<-
chan
struct
{}
}
func
(
c
*
openAIWSPingBlockingConn
)
WriteJSON
(
context
.
Context
,
any
)
error
{
return
nil
}
func
(
c
*
openAIWSPingBlockingConn
)
ReadMessage
(
context
.
Context
)
([]
byte
,
error
)
{
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`
),
nil
}
func
(
c
*
openAIWSPingBlockingConn
)
Ping
(
ctx
context
.
Context
)
error
{
if
c
.
current
==
nil
||
c
.
maxConcurrent
==
nil
{
return
nil
}
now
:=
c
.
current
.
Add
(
1
)
for
{
prev
:=
c
.
maxConcurrent
.
Load
()
if
now
<=
prev
||
c
.
maxConcurrent
.
CompareAndSwap
(
prev
,
now
)
{
break
}
}
defer
c
.
current
.
Add
(
-
1
)
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
case
<-
c
.
release
:
return
nil
}
}
func
(
c
*
openAIWSPingBlockingConn
)
Close
()
error
{
return
nil
}
func
(
d
*
openAIWSCountingDialer
)
Dial
(
ctx
context
.
Context
,
wsURL
string
,
headers
http
.
Header
,
proxyURL
string
,
)
(
openAIWSClientConn
,
int
,
http
.
Header
,
error
)
{
_
=
ctx
_
=
wsURL
_
=
headers
_
=
proxyURL
d
.
mu
.
Lock
()
d
.
dialCount
++
d
.
mu
.
Unlock
()
return
&
openAIWSFakeConn
{},
0
,
nil
,
nil
}
func
(
d
*
openAIWSCountingDialer
)
DialCount
()
int
{
d
.
mu
.
Lock
()
defer
d
.
mu
.
Unlock
()
return
d
.
dialCount
}
func
(
d
*
openAIWSAlwaysFailDialer
)
Dial
(
ctx
context
.
Context
,
wsURL
string
,
headers
http
.
Header
,
proxyURL
string
,
)
(
openAIWSClientConn
,
int
,
http
.
Header
,
error
)
{
_
=
ctx
_
=
wsURL
_
=
headers
_
=
proxyURL
d
.
mu
.
Lock
()
d
.
dialCount
++
d
.
mu
.
Unlock
()
return
nil
,
503
,
nil
,
errors
.
New
(
"dial failed"
)
}
func
(
d
*
openAIWSAlwaysFailDialer
)
DialCount
()
int
{
d
.
mu
.
Lock
()
defer
d
.
mu
.
Unlock
()
return
d
.
dialCount
}
type
openAIWSFakeConn
struct
{
mu
sync
.
Mutex
closed
bool
payload
[][]
byte
}
func
(
c
*
openAIWSFakeConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
_
=
ctx
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
if
c
.
closed
{
return
errors
.
New
(
"closed"
)
}
c
.
payload
=
append
(
c
.
payload
,
[]
byte
(
"ok"
))
_
=
value
return
nil
}
func
(
c
*
openAIWSFakeConn
)
ReadMessage
(
ctx
context
.
Context
)
([]
byte
,
error
)
{
_
=
ctx
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
if
c
.
closed
{
return
nil
,
errors
.
New
(
"closed"
)
}
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_fake"}}`
),
nil
}
func
(
c
*
openAIWSFakeConn
)
Ping
(
ctx
context
.
Context
)
error
{
_
=
ctx
return
nil
}
func
(
c
*
openAIWSFakeConn
)
Close
()
error
{
c
.
mu
.
Lock
()
defer
c
.
mu
.
Unlock
()
c
.
closed
=
true
return
nil
}
type
openAIWSBlockingConn
struct
{
readDelay
time
.
Duration
}
func
(
c
*
openAIWSBlockingConn
)
WriteJSON
(
ctx
context
.
Context
,
value
any
)
error
{
_
=
ctx
_
=
value
return
nil
}
func
(
c
*
openAIWSBlockingConn
)
ReadMessage
(
ctx
context
.
Context
)
([]
byte
,
error
)
{
delay
:=
c
.
readDelay
if
delay
<=
0
{
delay
=
10
*
time
.
Millisecond
}
timer
:=
time
.
NewTimer
(
delay
)
defer
timer
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
case
<-
timer
.
C
:
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_blocking"}}`
),
nil
}
}
func
(
c
*
openAIWSBlockingConn
)
Ping
(
ctx
context
.
Context
)
error
{
_
=
ctx
return
nil
}
func
(
c
*
openAIWSBlockingConn
)
Close
()
error
{
return
nil
}
type
openAIWSWriteBlockingConn
struct
{}
func
(
c
*
openAIWSWriteBlockingConn
)
WriteJSON
(
ctx
context
.
Context
,
_
any
)
error
{
<-
ctx
.
Done
()
return
ctx
.
Err
()
}
func
(
c
*
openAIWSWriteBlockingConn
)
ReadMessage
(
context
.
Context
)
([]
byte
,
error
)
{
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_write_block"}}`
),
nil
}
func
(
c
*
openAIWSWriteBlockingConn
)
Ping
(
context
.
Context
)
error
{
return
nil
}
func
(
c
*
openAIWSWriteBlockingConn
)
Close
()
error
{
return
nil
}
type
openAIWSPingFailConn
struct
{}
func
(
c
*
openAIWSPingFailConn
)
WriteJSON
(
context
.
Context
,
any
)
error
{
return
nil
}
func
(
c
*
openAIWSPingFailConn
)
ReadMessage
(
context
.
Context
)
([]
byte
,
error
)
{
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`
),
nil
}
func
(
c
*
openAIWSPingFailConn
)
Ping
(
context
.
Context
)
error
{
return
errors
.
New
(
"ping failed"
)
}
func
(
c
*
openAIWSPingFailConn
)
Close
()
error
{
return
nil
}
type
openAIWSContextProbeConn
struct
{
lastWriteCtx
context
.
Context
}
func
(
c
*
openAIWSContextProbeConn
)
WriteJSON
(
ctx
context
.
Context
,
_
any
)
error
{
c
.
lastWriteCtx
=
ctx
return
nil
}
func
(
c
*
openAIWSContextProbeConn
)
ReadMessage
(
context
.
Context
)
([]
byte
,
error
)
{
return
[]
byte
(
`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`
),
nil
}
func
(
c
*
openAIWSContextProbeConn
)
Ping
(
context
.
Context
)
error
{
return
nil
}
func
(
c
*
openAIWSContextProbeConn
)
Close
()
error
{
return
nil
}
type
openAIWSNilConnDialer
struct
{}
func
(
d
*
openAIWSNilConnDialer
)
Dial
(
ctx
context
.
Context
,
wsURL
string
,
headers
http
.
Header
,
proxyURL
string
,
)
(
openAIWSClientConn
,
int
,
http
.
Header
,
error
)
{
_
=
ctx
_
=
wsURL
_
=
headers
_
=
proxyURL
return
nil
,
200
,
nil
,
nil
}
func
TestOpenAIWSConnPool_DialConnNilConnection
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
MaxConnsPerAccount
=
2
cfg
.
Gateway
.
OpenAIWS
.
DialTimeoutSeconds
=
1
pool
:=
newOpenAIWSConnPool
(
cfg
)
pool
.
setClientDialerForTest
(
&
openAIWSNilConnDialer
{})
account
:=
&
Account
{
ID
:
91
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
pool
.
Acquire
(
context
.
Background
(),
openAIWSAcquireRequest
{
Account
:
account
,
WSURL
:
"wss://example.com/v1/responses"
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"nil connection"
)
}
func
TestOpenAIWSConnPool_SnapshotTransportMetrics
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
pool
:=
newOpenAIWSConnPool
(
cfg
)
dialer
,
ok
:=
pool
.
clientDialer
.
(
*
coderOpenAIWSClientDialer
)
require
.
True
(
t
,
ok
)
_
,
err
:=
dialer
.
proxyHTTPClient
(
"http://127.0.0.1:28080"
)
require
.
NoError
(
t
,
err
)
_
,
err
=
dialer
.
proxyHTTPClient
(
"http://127.0.0.1:28080"
)
require
.
NoError
(
t
,
err
)
_
,
err
=
dialer
.
proxyHTTPClient
(
"http://127.0.0.1:28081"
)
require
.
NoError
(
t
,
err
)
snapshot
:=
pool
.
SnapshotTransportMetrics
()
require
.
Equal
(
t
,
int64
(
1
),
snapshot
.
ProxyClientCacheHits
)
require
.
Equal
(
t
,
int64
(
2
),
snapshot
.
ProxyClientCacheMisses
)
require
.
InDelta
(
t
,
1.0
/
3.0
,
snapshot
.
TransportReuseRatio
,
0.0001
)
}
backend/internal/service/openai_ws_protocol_forward_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
wsFallbackServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
http
.
NotFound
(
w
,
r
)
}))
defer
wsFallbackServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsFallbackServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 模式下失败时不应回退 HTTP"
)
}
func
TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
wsFallbackServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
http
.
NotFound
(
w
,
r
)
}))
defer
wsFallbackServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
SetOpenAIClientTransport
(
c
,
OpenAIClientTransportHTTP
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
101
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsFallbackServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
OpenAIWSMode
,
"HTTP 入站应保持 HTTP 转发"
)
require
.
NotNil
(
t
,
upstream
.
lastReq
,
"HTTP 入站应命中 HTTP 上游"
)
require
.
False
(
t
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"previous_response_id"
)
.
Exists
(),
"HTTP 路径应沿用原逻辑移除 previous_response_id"
)
decision
,
_
:=
c
.
Get
(
"openai_ws_transport_decision"
)
reason
,
_
:=
c
.
Get
(
"openai_ws_transport_reason"
)
require
.
Equal
(
t
,
string
(
OpenAIUpstreamTransportHTTPSSE
),
decision
)
require
.
Equal
(
t
,
"client_protocol_http"
,
reason
)
}
func
TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
wsFallbackServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
http
.
NotFound
(
w
,
r
)
}))
defer
wsFallbackServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
false
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsFallbackServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"previous_response_id"
)
.
Exists
())
}
func
TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ws426Server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUpgradeRequired
)
_
,
_
=
w
.
Write
([]
byte
(
`upgrade required`
))
}))
defer
ws426Server
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
12
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ws426Server
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"upgrade_required"
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 模式下不应再回退 HTTP"
)
require
.
Equal
(
t
,
http
.
StatusUpgradeRequired
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"426"
)
}
func
TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
http
.
NotFound
(
w
,
r
)
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
30
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
21
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
svc
.
markOpenAIWSFallbackCooling
(
account
.
ID
,
"upgrade_required"
)
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 模式下不应再回退 HTTP"
)
_
,
ok
:=
c
.
Get
(
"openai_ws_fallback_cooling"
)
require
.
False
(
t
,
ok
,
"已移除 fallback cooling 快捷回退路径"
)
}
func
TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsockets
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
false
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
31
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
"https://api.openai.com/v1/responses"
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"ws v1"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"WSv1"
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WSv1 不支持时不应触发 HTTP 上游请求"
)
}
func
TestNewOpenAIGatewayService_InitializesOpenAIWSResolver
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewOpenAIGatewayService
(
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
decision
:=
svc
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
nil
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_missing"
,
decision
.
Reason
)
}
func
TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ws426Server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUpgradeRequired
)
_
,
_
=
w
.
Write
([]
byte
(
`upgrade required`
))
}))
defer
ws426Server
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
c
.
String
(
http
.
StatusAccepted
,
"already-written"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
}
account
:=
&
Account
{
ID
:
41
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ws426Server
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"ws fallback"
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"已写下游响应时,不应再回退 HTTP"
)
}
func
TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
// 仅发送 response.created(非 token 事件)后立即关闭,
// 模拟线上“上游早期内部错误断连”的场景。
if
err
:=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.created"
,
"response"
:
map
[
string
]
any
{
"id"
:
"resp_ws_created_only"
,
"model"
:
"gpt-5.3-codex"
,
},
});
err
!=
nil
{
t
.
Errorf
(
"write response.created failed: %v"
,
err
)
return
}
closePayload
:=
websocket
.
FormatCloseMessage
(
websocket
.
CloseInternalServerErr
,
""
)
_
=
conn
.
WriteControl
(
websocket
.
CloseMessage
,
closePayload
,
time
.
Now
()
.
Add
(
time
.
Second
))
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
"data: {
\"
type
\"
:
\"
response.output_text.delta
\"
,
\"
delta
\"
:
\"
ok
\"
}
\n\n
"
+
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
id
\"
:
\"
resp_http_fallback
\"
,
\"
usage
\"
:{
\"
input_tokens
\"
:2,
\"
output_tokens
\"
:1}}}
\n\n
"
+
"data: [DONE]
\n\n
"
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
88
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 早期断连后不应再回退 HTTP"
)
require
.
Empty
(
t
,
rec
.
Body
.
String
(),
"未产出 token 前上游断连时不应写入下游半截流"
)
}
func
TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
closePayload
:=
websocket
.
FormatCloseMessage
(
websocket
.
CloseInternalServerErr
,
""
)
_
=
conn
.
WriteControl
(
websocket
.
CloseMessage
,
closePayload
,
time
.
Now
()
.
Add
(
time
.
Second
))
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
"data: {
\"
type
\"
:
\"
response.output_text.delta
\"
,
\"
delta
\"
:
\"
ok
\"
}
\n\n
"
+
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
id
\"
:
\"
resp_retry_http_fallback
\"
,
\"
usage
\"
:{
\"
input_tokens
\"
:2,
\"
output_tokens
\"
:1}}}
\n\n
"
+
"data: [DONE]
\n\n
"
,
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
89
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 重连耗尽后不应再回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
openAIWSReconnectRetryLimit
+
1
),
wsAttempts
.
Load
())
}
func
TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
closePayload
:=
websocket
.
FormatCloseMessage
(
websocket
.
ClosePolicyViolation
,
""
)
_
=
conn
.
WriteControl
(
websocket
.
CloseMessage
,
closePayload
,
time
.
Now
()
.
Add
(
time
.
Second
))
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
cfg
.
Gateway
.
OpenAIWS
.
RetryBackoffInitialMS
=
1
cfg
.
Gateway
.
OpenAIWS
.
RetryBackoffMaxMS
=
2
cfg
.
Gateway
.
OpenAIWS
.
RetryJitterRatio
=
0
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
8901
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"策略违规关闭后不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
1
),
wsAttempts
.
Load
(),
"策略违规不应进行 WS 重试"
)
}
func
TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
any
{
"code"
:
"websocket_connection_limit_reached"
,
"type"
:
"server_error"
,
"message"
:
"websocket connection limit reached"
,
},
})
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
90
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"触发 websocket_connection_limit_reached 后不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
openAIWSReconnectRetryLimit
+
1
),
wsAttempts
.
Load
())
}
func
TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
var
wsRequestPayloads
[][]
byte
var
wsRequestMu
sync
.
Mutex
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
attempt
:=
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
reqRaw
,
_
:=
json
.
Marshal
(
req
)
wsRequestMu
.
Lock
()
wsRequestPayloads
=
append
(
wsRequestPayloads
,
reqRaw
)
wsRequestMu
.
Unlock
()
if
attempt
==
1
{
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
any
{
"code"
:
"previous_response_not_found"
,
"type"
:
"invalid_request_error"
,
"message"
:
"previous response not found"
,
},
})
return
}
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"response.completed"
,
"response"
:
map
[
string
]
any
{
"id"
:
"resp_ws_prev_recover_ok"
,
"model"
:
"gpt-5.3-codex"
,
"usage"
:
map
[
string
]
any
{
"input_tokens"
:
1
,
"output_tokens"
:
1
,
"input_tokens_details"
:
map
[
string
]
any
{
"cached_tokens"
:
0
,
},
},
},
})
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
91
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_ws_prev_recover_ok"
,
result
.
RequestID
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"previous_response_not_found 不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
2
),
wsAttempts
.
Load
(),
"previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试"
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"resp_ws_prev_recover_ok"
,
gjson
.
Get
(
rec
.
Body
.
String
(),
"id"
)
.
String
())
wsRequestMu
.
Lock
()
requests
:=
append
([][]
byte
(
nil
),
wsRequestPayloads
...
)
wsRequestMu
.
Unlock
()
require
.
Len
(
t
,
requests
,
2
)
require
.
True
(
t
,
gjson
.
GetBytes
(
requests
[
0
],
"previous_response_id"
)
.
Exists
(),
"首轮请求应保留 previous_response_id"
)
require
.
False
(
t
,
gjson
.
GetBytes
(
requests
[
1
],
"previous_response_id"
)
.
Exists
(),
"恢复重试应移除 previous_response_id"
)
}
func
TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
var
wsRequestPayloads
[][]
byte
var
wsRequestMu
sync
.
Mutex
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
reqRaw
,
_
:=
json
.
Marshal
(
req
)
wsRequestMu
.
Lock
()
wsRequestPayloads
=
append
(
wsRequestPayloads
,
reqRaw
)
wsRequestMu
.
Unlock
()
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
any
{
"code"
:
"previous_response_not_found"
,
"type"
:
"invalid_request_error"
,
"message"
:
"previous response not found"
,
},
})
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
92
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"previous_response_not_found 不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
1
),
wsAttempts
.
Load
(),
"function_call_output 场景应跳过 previous_response_not_found 自动恢复"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
strings
.
ToLower
(
rec
.
Body
.
String
()),
"previous response not found"
)
wsRequestMu
.
Lock
()
requests
:=
append
([][]
byte
(
nil
),
wsRequestPayloads
...
)
wsRequestMu
.
Unlock
()
require
.
Len
(
t
,
requests
,
1
)
require
.
True
(
t
,
gjson
.
GetBytes
(
requests
[
0
],
"previous_response_id"
)
.
Exists
())
}
func
TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
var
wsRequestPayloads
[][]
byte
var
wsRequestMu
sync
.
Mutex
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
reqRaw
,
_
:=
json
.
Marshal
(
req
)
wsRequestMu
.
Lock
()
wsRequestPayloads
=
append
(
wsRequestPayloads
,
reqRaw
)
wsRequestMu
.
Unlock
()
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
any
{
"code"
:
"previous_response_not_found"
,
"type"
:
"invalid_request_error"
,
"message"
:
"previous response not found"
,
},
})
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
93
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 模式下 previous_response_not_found 不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
1
),
wsAttempts
.
Load
(),
"缺少 previous_response_id 时应跳过自动恢复重试"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
wsRequestMu
.
Lock
()
requests
:=
append
([][]
byte
(
nil
),
wsRequestPayloads
...
)
wsRequestMu
.
Unlock
()
require
.
Len
(
t
,
requests
,
1
)
require
.
False
(
t
,
gjson
.
GetBytes
(
requests
[
0
],
"previous_response_id"
)
.
Exists
())
}
func
TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
var
wsAttempts
atomic
.
Int32
var
wsRequestPayloads
[][]
byte
var
wsRequestMu
sync
.
Mutex
upgrader
:=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
true
}}
wsServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
wsAttempts
.
Add
(
1
)
conn
,
err
:=
upgrader
.
Upgrade
(
w
,
r
,
nil
)
if
err
!=
nil
{
t
.
Errorf
(
"upgrade websocket failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
var
req
map
[
string
]
any
if
err
:=
conn
.
ReadJSON
(
&
req
);
err
!=
nil
{
t
.
Errorf
(
"read ws request failed: %v"
,
err
)
return
}
reqRaw
,
_
:=
json
.
Marshal
(
req
)
wsRequestMu
.
Lock
()
wsRequestPayloads
=
append
(
wsRequestPayloads
,
reqRaw
)
wsRequestMu
.
Unlock
()
_
=
conn
.
WriteJSON
(
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
any
{
"code"
:
"previous_response_not_found"
,
"type"
:
"invalid_request_error"
,
"message"
:
"previous response not found"
,
},
})
}))
defer
wsServer
.
Close
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"custom-client/1.0"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
)),
},
}
cfg
:=
&
config
.
Config
{}
cfg
.
Security
.
URLAllowlist
.
Enabled
=
false
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
=
true
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
FallbackCooldownSeconds
=
1
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
,
httpUpstream
:
upstream
,
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
toolCorrector
:
NewCodexToolCorrector
(),
}
account
:=
&
Account
{
ID
:
94
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
wsServer
.
URL
,
},
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
body
:=
[]
byte
(
`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
upstream
.
lastReq
,
"WS 模式下 previous_response_not_found 不应回退 HTTP"
)
require
.
Equal
(
t
,
int32
(
2
),
wsAttempts
.
Load
(),
"应只允许一次自动恢复重试"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
wsRequestMu
.
Lock
()
requests
:=
append
([][]
byte
(
nil
),
wsRequestPayloads
...
)
wsRequestMu
.
Unlock
()
require
.
Len
(
t
,
requests
,
2
)
require
.
True
(
t
,
gjson
.
GetBytes
(
requests
[
0
],
"previous_response_id"
)
.
Exists
(),
"首轮请求应包含 previous_response_id"
)
require
.
False
(
t
,
gjson
.
GetBytes
(
requests
[
1
],
"previous_response_id"
)
.
Exists
(),
"恢复重试应移除 previous_response_id"
)
}
backend/internal/service/openai_ws_protocol_resolver.go
0 → 100644
View file @
bb664d9b
package
service
import
"github.com/Wei-Shaw/sub2api/internal/config"
// OpenAIUpstreamTransport 表示 OpenAI 上游传输协议。
type
OpenAIUpstreamTransport
string
const
(
OpenAIUpstreamTransportAny
OpenAIUpstreamTransport
=
""
OpenAIUpstreamTransportHTTPSSE
OpenAIUpstreamTransport
=
"http_sse"
OpenAIUpstreamTransportResponsesWebsocket
OpenAIUpstreamTransport
=
"responses_websockets"
OpenAIUpstreamTransportResponsesWebsocketV2
OpenAIUpstreamTransport
=
"responses_websockets_v2"
)
// OpenAIWSProtocolDecision 表示协议决策结果。
type
OpenAIWSProtocolDecision
struct
{
Transport
OpenAIUpstreamTransport
Reason
string
}
// OpenAIWSProtocolResolver 定义 OpenAI 上游协议决策。
type
OpenAIWSProtocolResolver
interface
{
Resolve
(
account
*
Account
)
OpenAIWSProtocolDecision
}
type
defaultOpenAIWSProtocolResolver
struct
{
cfg
*
config
.
Config
}
// NewOpenAIWSProtocolResolver 创建默认协议决策器。
func
NewOpenAIWSProtocolResolver
(
cfg
*
config
.
Config
)
OpenAIWSProtocolResolver
{
return
&
defaultOpenAIWSProtocolResolver
{
cfg
:
cfg
}
}
func
(
r
*
defaultOpenAIWSProtocolResolver
)
Resolve
(
account
*
Account
)
OpenAIWSProtocolDecision
{
if
account
==
nil
{
return
openAIWSHTTPDecision
(
"account_missing"
)
}
if
!
account
.
IsOpenAI
()
{
return
openAIWSHTTPDecision
(
"platform_not_openai"
)
}
if
account
.
IsOpenAIWSForceHTTPEnabled
()
{
return
openAIWSHTTPDecision
(
"account_force_http"
)
}
if
r
==
nil
||
r
.
cfg
==
nil
{
return
openAIWSHTTPDecision
(
"config_missing"
)
}
wsCfg
:=
r
.
cfg
.
Gateway
.
OpenAIWS
if
wsCfg
.
ForceHTTP
{
return
openAIWSHTTPDecision
(
"global_force_http"
)
}
if
!
wsCfg
.
Enabled
{
return
openAIWSHTTPDecision
(
"global_disabled"
)
}
if
account
.
IsOpenAIOAuth
()
{
if
!
wsCfg
.
OAuthEnabled
{
return
openAIWSHTTPDecision
(
"oauth_disabled"
)
}
}
else
if
account
.
IsOpenAIApiKey
()
{
if
!
wsCfg
.
APIKeyEnabled
{
return
openAIWSHTTPDecision
(
"apikey_disabled"
)
}
}
else
{
return
openAIWSHTTPDecision
(
"unknown_auth_type"
)
}
if
wsCfg
.
ModeRouterV2Enabled
{
mode
:=
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
wsCfg
.
IngressModeDefault
)
switch
mode
{
case
OpenAIWSIngressModeOff
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
case
OpenAIWSIngressModeShared
,
OpenAIWSIngressModeDedicated
:
// continue
default
:
return
openAIWSHTTPDecision
(
"account_mode_off"
)
}
if
account
.
Concurrency
<=
0
{
return
openAIWSHTTPDecision
(
"account_concurrency_invalid"
)
}
if
wsCfg
.
ResponsesWebsocketsV2
{
return
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportResponsesWebsocketV2
,
Reason
:
"ws_v2_mode_"
+
mode
,
}
}
if
wsCfg
.
ResponsesWebsockets
{
return
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportResponsesWebsocket
,
Reason
:
"ws_v1_mode_"
+
mode
,
}
}
return
openAIWSHTTPDecision
(
"feature_disabled"
)
}
if
!
account
.
IsOpenAIResponsesWebSocketV2Enabled
()
{
return
openAIWSHTTPDecision
(
"account_disabled"
)
}
if
wsCfg
.
ResponsesWebsocketsV2
{
return
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportResponsesWebsocketV2
,
Reason
:
"ws_v2_enabled"
,
}
}
if
wsCfg
.
ResponsesWebsockets
{
return
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportResponsesWebsocket
,
Reason
:
"ws_v1_enabled"
,
}
}
return
openAIWSHTTPDecision
(
"feature_disabled"
)
}
func
openAIWSHTTPDecision
(
reason
string
)
OpenAIWSProtocolDecision
{
return
OpenAIWSProtocolDecision
{
Transport
:
OpenAIUpstreamTransportHTTPSSE
,
Reason
:
reason
,
}
}
backend/internal/service/openai_ws_protocol_resolver_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestOpenAIWSProtocolResolver_Resolve
(
t
*
testing
.
T
)
{
baseCfg
:=
&
config
.
Config
{}
baseCfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
baseCfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
baseCfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
baseCfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsockets
=
false
baseCfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
openAIOAuthEnabled
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
true
,
},
}
t
.
Run
(
"v2优先"
,
func
(
t
*
testing
.
T
)
{
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
openAIOAuthEnabled
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_enabled"
,
decision
.
Reason
)
})
t
.
Run
(
"v2关闭时回退v1"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
*
baseCfg
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
false
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsockets
=
true
decision
:=
NewOpenAIWSProtocolResolver
(
&
cfg
)
.
Resolve
(
openAIOAuthEnabled
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocket
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v1_enabled"
,
decision
.
Reason
)
})
t
.
Run
(
"透传开关不影响WS协议判定"
,
func
(
t
*
testing
.
T
)
{
account
:=
*
openAIOAuthEnabled
account
.
Extra
=
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
true
,
"openai_passthrough"
:
true
,
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
&
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_enabled"
,
decision
.
Reason
)
})
t
.
Run
(
"账号级强制HTTP"
,
func
(
t
*
testing
.
T
)
{
account
:=
*
openAIOAuthEnabled
account
.
Extra
=
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
true
,
"openai_ws_force_http"
:
true
,
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
&
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_force_http"
,
decision
.
Reason
)
})
t
.
Run
(
"全局关闭保持HTTP"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
*
baseCfg
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
false
decision
:=
NewOpenAIWSProtocolResolver
(
&
cfg
)
.
Resolve
(
openAIOAuthEnabled
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"global_disabled"
,
decision
.
Reason
)
})
t
.
Run
(
"账号开关关闭保持HTTP"
,
func
(
t
*
testing
.
T
)
{
account
:=
*
openAIOAuthEnabled
account
.
Extra
=
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
&
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_disabled"
,
decision
.
Reason
)
})
t
.
Run
(
"OAuth账号不会读取API Key专用开关"
,
func
(
t
*
testing
.
T
)
{
account
:=
*
openAIOAuthEnabled
account
.
Extra
=
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
&
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_disabled"
,
decision
.
Reason
)
})
t
.
Run
(
"兼容旧键openai_ws_enabled"
,
func
(
t
*
testing
.
T
)
{
account
:=
*
openAIOAuthEnabled
account
.
Extra
=
map
[
string
]
any
{
"openai_ws_enabled"
:
true
,
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
&
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_enabled"
,
decision
.
Reason
)
})
t
.
Run
(
"按账号类型开关控制"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
*
baseCfg
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
false
decision
:=
NewOpenAIWSProtocolResolver
(
&
cfg
)
.
Resolve
(
openAIOAuthEnabled
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"oauth_disabled"
,
decision
.
Reason
)
})
t
.
Run
(
"API Key 账号关闭开关时回退HTTP"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
*
baseCfg
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
false
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
&
cfg
)
.
Resolve
(
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"apikey_disabled"
,
decision
.
Reason
)
})
t
.
Run
(
"未知认证类型回退HTTP"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
"unknown_type"
,
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
baseCfg
)
.
Resolve
(
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"unknown_auth_type"
,
decision
.
Reason
)
})
}
func
TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Gateway
.
OpenAIWS
.
Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
OAuthEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
APIKeyEnabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
ResponsesWebsocketsV2
=
true
cfg
.
Gateway
.
OpenAIWS
.
ModeRouterV2Enabled
=
true
cfg
.
Gateway
.
OpenAIWS
.
IngressModeDefault
=
OpenAIWSIngressModeShared
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeDedicated
,
},
}
t
.
Run
(
"dedicated mode routes to ws v2"
,
func
(
t
*
testing
.
T
)
{
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
account
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_dedicated"
,
decision
.
Reason
)
})
t
.
Run
(
"off mode routes to http"
,
func
(
t
*
testing
.
T
)
{
offAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeOff
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
offAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_mode_off"
,
decision
.
Reason
)
})
t
.
Run
(
"legacy boolean maps to shared in v2 router"
,
func
(
t
*
testing
.
T
)
{
legacyAccount
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
legacyAccount
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportResponsesWebsocketV2
,
decision
.
Transport
)
require
.
Equal
(
t
,
"ws_v2_mode_shared"
,
decision
.
Reason
)
})
t
.
Run
(
"non-positive concurrency is rejected in v2 router"
,
func
(
t
*
testing
.
T
)
{
invalidConcurrency
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeShared
,
},
}
decision
:=
NewOpenAIWSProtocolResolver
(
cfg
)
.
Resolve
(
invalidConcurrency
)
require
.
Equal
(
t
,
OpenAIUpstreamTransportHTTPSSE
,
decision
.
Transport
)
require
.
Equal
(
t
,
"account_concurrency_invalid"
,
decision
.
Reason
)
})
}
backend/internal/service/openai_ws_state_store.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
)
const
(
openAIWSResponseAccountCachePrefix
=
"openai:response:"
openAIWSStateStoreCleanupInterval
=
time
.
Minute
openAIWSStateStoreCleanupMaxPerMap
=
512
openAIWSStateStoreMaxEntriesPerMap
=
65536
openAIWSStateStoreRedisTimeout
=
3
*
time
.
Second
)
type
openAIWSAccountBinding
struct
{
accountID
int64
expiresAt
time
.
Time
}
type
openAIWSConnBinding
struct
{
connID
string
expiresAt
time
.
Time
}
type
openAIWSTurnStateBinding
struct
{
turnState
string
expiresAt
time
.
Time
}
type
openAIWSSessionConnBinding
struct
{
connID
string
expiresAt
time
.
Time
}
// OpenAIWSStateStore 管理 WSv2 的粘连状态。
// - response_id -> account_id 用于续链路由
// - response_id -> conn_id 用于连接内上下文复用
//
// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。
// response_id -> conn_id 仅在本进程内有效。
type
OpenAIWSStateStore
interface
{
BindResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
GetResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
)
(
int64
,
error
)
DeleteResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
)
error
BindResponseConn
(
responseID
,
connID
string
,
ttl
time
.
Duration
)
GetResponseConn
(
responseID
string
)
(
string
,
bool
)
DeleteResponseConn
(
responseID
string
)
BindSessionTurnState
(
groupID
int64
,
sessionHash
,
turnState
string
,
ttl
time
.
Duration
)
GetSessionTurnState
(
groupID
int64
,
sessionHash
string
)
(
string
,
bool
)
DeleteSessionTurnState
(
groupID
int64
,
sessionHash
string
)
BindSessionConn
(
groupID
int64
,
sessionHash
,
connID
string
,
ttl
time
.
Duration
)
GetSessionConn
(
groupID
int64
,
sessionHash
string
)
(
string
,
bool
)
DeleteSessionConn
(
groupID
int64
,
sessionHash
string
)
}
type
defaultOpenAIWSStateStore
struct
{
cache
GatewayCache
responseToAccountMu
sync
.
RWMutex
responseToAccount
map
[
string
]
openAIWSAccountBinding
responseToConnMu
sync
.
RWMutex
responseToConn
map
[
string
]
openAIWSConnBinding
sessionToTurnStateMu
sync
.
RWMutex
sessionToTurnState
map
[
string
]
openAIWSTurnStateBinding
sessionToConnMu
sync
.
RWMutex
sessionToConn
map
[
string
]
openAIWSSessionConnBinding
lastCleanupUnixNano
atomic
.
Int64
}
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
func
NewOpenAIWSStateStore
(
cache
GatewayCache
)
OpenAIWSStateStore
{
store
:=
&
defaultOpenAIWSStateStore
{
cache
:
cache
,
responseToAccount
:
make
(
map
[
string
]
openAIWSAccountBinding
,
256
),
responseToConn
:
make
(
map
[
string
]
openAIWSConnBinding
,
256
),
sessionToTurnState
:
make
(
map
[
string
]
openAIWSTurnStateBinding
,
256
),
sessionToConn
:
make
(
map
[
string
]
openAIWSSessionConnBinding
,
256
),
}
store
.
lastCleanupUnixNano
.
Store
(
time
.
Now
()
.
UnixNano
())
return
store
}
func
(
s
*
defaultOpenAIWSStateStore
)
BindResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
if
id
==
""
||
accountID
<=
0
{
return
nil
}
ttl
=
normalizeOpenAIWSTTL
(
ttl
)
s
.
maybeCleanup
()
expiresAt
:=
time
.
Now
()
.
Add
(
ttl
)
s
.
responseToAccountMu
.
Lock
()
ensureBindingCapacity
(
s
.
responseToAccount
,
id
,
openAIWSStateStoreMaxEntriesPerMap
)
s
.
responseToAccount
[
id
]
=
openAIWSAccountBinding
{
accountID
:
accountID
,
expiresAt
:
expiresAt
}
s
.
responseToAccountMu
.
Unlock
()
if
s
.
cache
==
nil
{
return
nil
}
cacheKey
:=
openAIWSResponseAccountCacheKey
(
id
)
cacheCtx
,
cancel
:=
withOpenAIWSStateStoreRedisTimeout
(
ctx
)
defer
cancel
()
return
s
.
cache
.
SetSessionAccountID
(
cacheCtx
,
groupID
,
cacheKey
,
accountID
,
ttl
)
}
func
(
s
*
defaultOpenAIWSStateStore
)
GetResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
)
(
int64
,
error
)
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
if
id
==
""
{
return
0
,
nil
}
s
.
maybeCleanup
()
now
:=
time
.
Now
()
s
.
responseToAccountMu
.
RLock
()
if
binding
,
ok
:=
s
.
responseToAccount
[
id
];
ok
{
if
now
.
Before
(
binding
.
expiresAt
)
{
accountID
:=
binding
.
accountID
s
.
responseToAccountMu
.
RUnlock
()
return
accountID
,
nil
}
}
s
.
responseToAccountMu
.
RUnlock
()
if
s
.
cache
==
nil
{
return
0
,
nil
}
cacheKey
:=
openAIWSResponseAccountCacheKey
(
id
)
cacheCtx
,
cancel
:=
withOpenAIWSStateStoreRedisTimeout
(
ctx
)
defer
cancel
()
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
cacheCtx
,
groupID
,
cacheKey
)
if
err
!=
nil
||
accountID
<=
0
{
// 缓存读取失败不阻断主流程,按未命中降级。
return
0
,
nil
}
return
accountID
,
nil
}
func
(
s
*
defaultOpenAIWSStateStore
)
DeleteResponseAccount
(
ctx
context
.
Context
,
groupID
int64
,
responseID
string
)
error
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
if
id
==
""
{
return
nil
}
s
.
responseToAccountMu
.
Lock
()
delete
(
s
.
responseToAccount
,
id
)
s
.
responseToAccountMu
.
Unlock
()
if
s
.
cache
==
nil
{
return
nil
}
cacheCtx
,
cancel
:=
withOpenAIWSStateStoreRedisTimeout
(
ctx
)
defer
cancel
()
return
s
.
cache
.
DeleteSessionAccountID
(
cacheCtx
,
groupID
,
openAIWSResponseAccountCacheKey
(
id
))
}
func
(
s
*
defaultOpenAIWSStateStore
)
BindResponseConn
(
responseID
,
connID
string
,
ttl
time
.
Duration
)
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
conn
:=
strings
.
TrimSpace
(
connID
)
if
id
==
""
||
conn
==
""
{
return
}
ttl
=
normalizeOpenAIWSTTL
(
ttl
)
s
.
maybeCleanup
()
s
.
responseToConnMu
.
Lock
()
ensureBindingCapacity
(
s
.
responseToConn
,
id
,
openAIWSStateStoreMaxEntriesPerMap
)
s
.
responseToConn
[
id
]
=
openAIWSConnBinding
{
connID
:
conn
,
expiresAt
:
time
.
Now
()
.
Add
(
ttl
),
}
s
.
responseToConnMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
GetResponseConn
(
responseID
string
)
(
string
,
bool
)
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
if
id
==
""
{
return
""
,
false
}
s
.
maybeCleanup
()
now
:=
time
.
Now
()
s
.
responseToConnMu
.
RLock
()
binding
,
ok
:=
s
.
responseToConn
[
id
]
s
.
responseToConnMu
.
RUnlock
()
if
!
ok
||
now
.
After
(
binding
.
expiresAt
)
||
strings
.
TrimSpace
(
binding
.
connID
)
==
""
{
return
""
,
false
}
return
binding
.
connID
,
true
}
func
(
s
*
defaultOpenAIWSStateStore
)
DeleteResponseConn
(
responseID
string
)
{
id
:=
normalizeOpenAIWSResponseID
(
responseID
)
if
id
==
""
{
return
}
s
.
responseToConnMu
.
Lock
()
delete
(
s
.
responseToConn
,
id
)
s
.
responseToConnMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
BindSessionTurnState
(
groupID
int64
,
sessionHash
,
turnState
string
,
ttl
time
.
Duration
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
state
:=
strings
.
TrimSpace
(
turnState
)
if
key
==
""
||
state
==
""
{
return
}
ttl
=
normalizeOpenAIWSTTL
(
ttl
)
s
.
maybeCleanup
()
s
.
sessionToTurnStateMu
.
Lock
()
ensureBindingCapacity
(
s
.
sessionToTurnState
,
key
,
openAIWSStateStoreMaxEntriesPerMap
)
s
.
sessionToTurnState
[
key
]
=
openAIWSTurnStateBinding
{
turnState
:
state
,
expiresAt
:
time
.
Now
()
.
Add
(
ttl
),
}
s
.
sessionToTurnStateMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
GetSessionTurnState
(
groupID
int64
,
sessionHash
string
)
(
string
,
bool
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
if
key
==
""
{
return
""
,
false
}
s
.
maybeCleanup
()
now
:=
time
.
Now
()
s
.
sessionToTurnStateMu
.
RLock
()
binding
,
ok
:=
s
.
sessionToTurnState
[
key
]
s
.
sessionToTurnStateMu
.
RUnlock
()
if
!
ok
||
now
.
After
(
binding
.
expiresAt
)
||
strings
.
TrimSpace
(
binding
.
turnState
)
==
""
{
return
""
,
false
}
return
binding
.
turnState
,
true
}
func
(
s
*
defaultOpenAIWSStateStore
)
DeleteSessionTurnState
(
groupID
int64
,
sessionHash
string
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
if
key
==
""
{
return
}
s
.
sessionToTurnStateMu
.
Lock
()
delete
(
s
.
sessionToTurnState
,
key
)
s
.
sessionToTurnStateMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
BindSessionConn
(
groupID
int64
,
sessionHash
,
connID
string
,
ttl
time
.
Duration
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
conn
:=
strings
.
TrimSpace
(
connID
)
if
key
==
""
||
conn
==
""
{
return
}
ttl
=
normalizeOpenAIWSTTL
(
ttl
)
s
.
maybeCleanup
()
s
.
sessionToConnMu
.
Lock
()
ensureBindingCapacity
(
s
.
sessionToConn
,
key
,
openAIWSStateStoreMaxEntriesPerMap
)
s
.
sessionToConn
[
key
]
=
openAIWSSessionConnBinding
{
connID
:
conn
,
expiresAt
:
time
.
Now
()
.
Add
(
ttl
),
}
s
.
sessionToConnMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
GetSessionConn
(
groupID
int64
,
sessionHash
string
)
(
string
,
bool
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
if
key
==
""
{
return
""
,
false
}
s
.
maybeCleanup
()
now
:=
time
.
Now
()
s
.
sessionToConnMu
.
RLock
()
binding
,
ok
:=
s
.
sessionToConn
[
key
]
s
.
sessionToConnMu
.
RUnlock
()
if
!
ok
||
now
.
After
(
binding
.
expiresAt
)
||
strings
.
TrimSpace
(
binding
.
connID
)
==
""
{
return
""
,
false
}
return
binding
.
connID
,
true
}
func
(
s
*
defaultOpenAIWSStateStore
)
DeleteSessionConn
(
groupID
int64
,
sessionHash
string
)
{
key
:=
openAIWSSessionTurnStateKey
(
groupID
,
sessionHash
)
if
key
==
""
{
return
}
s
.
sessionToConnMu
.
Lock
()
delete
(
s
.
sessionToConn
,
key
)
s
.
sessionToConnMu
.
Unlock
()
}
func
(
s
*
defaultOpenAIWSStateStore
)
maybeCleanup
()
{
if
s
==
nil
{
return
}
now
:=
time
.
Now
()
last
:=
time
.
Unix
(
0
,
s
.
lastCleanupUnixNano
.
Load
())
if
now
.
Sub
(
last
)
<
openAIWSStateStoreCleanupInterval
{
return
}
if
!
s
.
lastCleanupUnixNano
.
CompareAndSwap
(
last
.
UnixNano
(),
now
.
UnixNano
())
{
return
}
// 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。
s
.
responseToAccountMu
.
Lock
()
cleanupExpiredAccountBindings
(
s
.
responseToAccount
,
now
,
openAIWSStateStoreCleanupMaxPerMap
)
s
.
responseToAccountMu
.
Unlock
()
s
.
responseToConnMu
.
Lock
()
cleanupExpiredConnBindings
(
s
.
responseToConn
,
now
,
openAIWSStateStoreCleanupMaxPerMap
)
s
.
responseToConnMu
.
Unlock
()
s
.
sessionToTurnStateMu
.
Lock
()
cleanupExpiredTurnStateBindings
(
s
.
sessionToTurnState
,
now
,
openAIWSStateStoreCleanupMaxPerMap
)
s
.
sessionToTurnStateMu
.
Unlock
()
s
.
sessionToConnMu
.
Lock
()
cleanupExpiredSessionConnBindings
(
s
.
sessionToConn
,
now
,
openAIWSStateStoreCleanupMaxPerMap
)
s
.
sessionToConnMu
.
Unlock
()
}
func
cleanupExpiredAccountBindings
(
bindings
map
[
string
]
openAIWSAccountBinding
,
now
time
.
Time
,
maxScan
int
)
{
if
len
(
bindings
)
==
0
||
maxScan
<=
0
{
return
}
scanned
:=
0
for
key
,
binding
:=
range
bindings
{
if
now
.
After
(
binding
.
expiresAt
)
{
delete
(
bindings
,
key
)
}
scanned
++
if
scanned
>=
maxScan
{
break
}
}
}
func
cleanupExpiredConnBindings
(
bindings
map
[
string
]
openAIWSConnBinding
,
now
time
.
Time
,
maxScan
int
)
{
if
len
(
bindings
)
==
0
||
maxScan
<=
0
{
return
}
scanned
:=
0
for
key
,
binding
:=
range
bindings
{
if
now
.
After
(
binding
.
expiresAt
)
{
delete
(
bindings
,
key
)
}
scanned
++
if
scanned
>=
maxScan
{
break
}
}
}
func
cleanupExpiredTurnStateBindings
(
bindings
map
[
string
]
openAIWSTurnStateBinding
,
now
time
.
Time
,
maxScan
int
)
{
if
len
(
bindings
)
==
0
||
maxScan
<=
0
{
return
}
scanned
:=
0
for
key
,
binding
:=
range
bindings
{
if
now
.
After
(
binding
.
expiresAt
)
{
delete
(
bindings
,
key
)
}
scanned
++
if
scanned
>=
maxScan
{
break
}
}
}
func
cleanupExpiredSessionConnBindings
(
bindings
map
[
string
]
openAIWSSessionConnBinding
,
now
time
.
Time
,
maxScan
int
)
{
if
len
(
bindings
)
==
0
||
maxScan
<=
0
{
return
}
scanned
:=
0
for
key
,
binding
:=
range
bindings
{
if
now
.
After
(
binding
.
expiresAt
)
{
delete
(
bindings
,
key
)
}
scanned
++
if
scanned
>=
maxScan
{
break
}
}
}
func
ensureBindingCapacity
[
T
any
](
bindings
map
[
string
]
T
,
incomingKey
string
,
maxEntries
int
)
{
if
len
(
bindings
)
<
maxEntries
||
maxEntries
<=
0
{
return
}
if
_
,
exists
:=
bindings
[
incomingKey
];
exists
{
return
}
// 固定上限保护:淘汰任意一项,优先保证内存有界。
for
key
:=
range
bindings
{
delete
(
bindings
,
key
)
return
}
}
func
normalizeOpenAIWSResponseID
(
responseID
string
)
string
{
return
strings
.
TrimSpace
(
responseID
)
}
func
openAIWSResponseAccountCacheKey
(
responseID
string
)
string
{
sum
:=
sha256
.
Sum256
([]
byte
(
responseID
))
return
openAIWSResponseAccountCachePrefix
+
hex
.
EncodeToString
(
sum
[
:
])
}
func
normalizeOpenAIWSTTL
(
ttl
time
.
Duration
)
time
.
Duration
{
if
ttl
<=
0
{
return
time
.
Hour
}
return
ttl
}
func
openAIWSSessionTurnStateKey
(
groupID
int64
,
sessionHash
string
)
string
{
hash
:=
strings
.
TrimSpace
(
sessionHash
)
if
hash
==
""
{
return
""
}
return
fmt
.
Sprintf
(
"%d:%s"
,
groupID
,
hash
)
}
func
withOpenAIWSStateStoreRedisTimeout
(
ctx
context
.
Context
)
(
context
.
Context
,
context
.
CancelFunc
)
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
}
return
context
.
WithTimeout
(
ctx
,
openAIWSStateStoreRedisTimeout
)
}
backend/internal/service/openai_ws_state_store_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func
TestOpenAIWSStateStore_BindGetDeleteResponseAccount
(
t
*
testing
.
T
)
{
cache
:=
&
stubGatewayCache
{}
store
:=
NewOpenAIWSStateStore
(
cache
)
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_abc"
,
101
,
time
.
Minute
))
accountID
,
err
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_abc"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
101
),
accountID
)
require
.
NoError
(
t
,
store
.
DeleteResponseAccount
(
ctx
,
groupID
,
"resp_abc"
))
accountID
,
err
=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_abc"
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
accountID
)
}
func
TestOpenAIWSStateStore_ResponseConnTTL
(
t
*
testing
.
T
)
{
store
:=
NewOpenAIWSStateStore
(
nil
)
store
.
BindResponseConn
(
"resp_conn"
,
"conn_1"
,
30
*
time
.
Millisecond
)
connID
,
ok
:=
store
.
GetResponseConn
(
"resp_conn"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"conn_1"
,
connID
)
time
.
Sleep
(
60
*
time
.
Millisecond
)
_
,
ok
=
store
.
GetResponseConn
(
"resp_conn"
)
require
.
False
(
t
,
ok
)
}
func
TestOpenAIWSStateStore_SessionTurnStateTTL
(
t
*
testing
.
T
)
{
store
:=
NewOpenAIWSStateStore
(
nil
)
store
.
BindSessionTurnState
(
9
,
"session_hash_1"
,
"turn_state_1"
,
30
*
time
.
Millisecond
)
state
,
ok
:=
store
.
GetSessionTurnState
(
9
,
"session_hash_1"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"turn_state_1"
,
state
)
// group 隔离
_
,
ok
=
store
.
GetSessionTurnState
(
10
,
"session_hash_1"
)
require
.
False
(
t
,
ok
)
time
.
Sleep
(
60
*
time
.
Millisecond
)
_
,
ok
=
store
.
GetSessionTurnState
(
9
,
"session_hash_1"
)
require
.
False
(
t
,
ok
)
}
func
TestOpenAIWSStateStore_SessionConnTTL
(
t
*
testing
.
T
)
{
store
:=
NewOpenAIWSStateStore
(
nil
)
store
.
BindSessionConn
(
9
,
"session_hash_conn_1"
,
"conn_1"
,
30
*
time
.
Millisecond
)
connID
,
ok
:=
store
.
GetSessionConn
(
9
,
"session_hash_conn_1"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"conn_1"
,
connID
)
// group 隔离
_
,
ok
=
store
.
GetSessionConn
(
10
,
"session_hash_conn_1"
)
require
.
False
(
t
,
ok
)
time
.
Sleep
(
60
*
time
.
Millisecond
)
_
,
ok
=
store
.
GetSessionConn
(
9
,
"session_hash_conn_1"
)
require
.
False
(
t
,
ok
)
}
func
TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss
(
t
*
testing
.
T
)
{
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{}}
store
:=
NewOpenAIWSStateStore
(
cache
)
ctx
:=
context
.
Background
()
groupID
:=
int64
(
17
)
responseID
:=
"resp_cache_stale"
cacheKey
:=
openAIWSResponseAccountCacheKey
(
responseID
)
cache
.
sessionBindings
[
cacheKey
]
=
501
accountID
,
err
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
responseID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
501
),
accountID
)
delete
(
cache
.
sessionBindings
,
cacheKey
)
accountID
,
err
=
store
.
GetResponseAccount
(
ctx
,
groupID
,
responseID
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
accountID
,
"上游缓存失效后不应继续命中本地陈旧映射"
)
}
func
TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally
(
t
*
testing
.
T
)
{
raw
:=
NewOpenAIWSStateStore
(
nil
)
store
,
ok
:=
raw
.
(
*
defaultOpenAIWSStateStore
)
require
.
True
(
t
,
ok
)
expiredAt
:=
time
.
Now
()
.
Add
(
-
time
.
Minute
)
total
:=
2048
store
.
responseToConnMu
.
Lock
()
for
i
:=
0
;
i
<
total
;
i
++
{
store
.
responseToConn
[
fmt
.
Sprintf
(
"resp_%d"
,
i
)]
=
openAIWSConnBinding
{
connID
:
"conn_incremental"
,
expiresAt
:
expiredAt
,
}
}
store
.
responseToConnMu
.
Unlock
()
store
.
lastCleanupUnixNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
openAIWSStateStoreCleanupInterval
)
.
UnixNano
())
store
.
maybeCleanup
()
store
.
responseToConnMu
.
RLock
()
remainingAfterFirst
:=
len
(
store
.
responseToConn
)
store
.
responseToConnMu
.
RUnlock
()
require
.
Less
(
t
,
remainingAfterFirst
,
total
,
"单轮 cleanup 应至少有进展"
)
require
.
Greater
(
t
,
remainingAfterFirst
,
0
,
"增量清理不要求单轮清空全部键"
)
for
i
:=
0
;
i
<
8
;
i
++
{
store
.
lastCleanupUnixNano
.
Store
(
time
.
Now
()
.
Add
(
-
2
*
openAIWSStateStoreCleanupInterval
)
.
UnixNano
())
store
.
maybeCleanup
()
}
store
.
responseToConnMu
.
RLock
()
remaining
:=
len
(
store
.
responseToConn
)
store
.
responseToConnMu
.
RUnlock
()
require
.
Zero
(
t
,
remaining
,
"多轮 cleanup 后应逐步清空全部过期键"
)
}
func
TestEnsureBindingCapacity_EvictsOneWhenMapIsFull
(
t
*
testing
.
T
)
{
bindings
:=
map
[
string
]
int
{
"a"
:
1
,
"b"
:
2
,
}
ensureBindingCapacity
(
bindings
,
"c"
,
2
)
bindings
[
"c"
]
=
3
require
.
Len
(
t
,
bindings
,
2
)
require
.
Equal
(
t
,
3
,
bindings
[
"c"
])
}
func
TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey
(
t
*
testing
.
T
)
{
bindings
:=
map
[
string
]
int
{
"a"
:
1
,
"b"
:
2
,
}
ensureBindingCapacity
(
bindings
,
"a"
,
2
)
bindings
[
"a"
]
=
9
require
.
Len
(
t
,
bindings
,
2
)
require
.
Equal
(
t
,
9
,
bindings
[
"a"
])
}
type
openAIWSStateStoreTimeoutProbeCache
struct
{
setHasDeadline
bool
getHasDeadline
bool
deleteHasDeadline
bool
setDeadlineDelta
time
.
Duration
getDeadlineDelta
time
.
Duration
delDeadlineDelta
time
.
Duration
}
func
(
c
*
openAIWSStateStoreTimeoutProbeCache
)
GetSessionAccountID
(
ctx
context
.
Context
,
_
int64
,
_
string
)
(
int64
,
error
)
{
if
deadline
,
ok
:=
ctx
.
Deadline
();
ok
{
c
.
getHasDeadline
=
true
c
.
getDeadlineDelta
=
time
.
Until
(
deadline
)
}
return
123
,
nil
}
func
(
c
*
openAIWSStateStoreTimeoutProbeCache
)
SetSessionAccountID
(
ctx
context
.
Context
,
_
int64
,
_
string
,
_
int64
,
_
time
.
Duration
)
error
{
if
deadline
,
ok
:=
ctx
.
Deadline
();
ok
{
c
.
setHasDeadline
=
true
c
.
setDeadlineDelta
=
time
.
Until
(
deadline
)
}
return
errors
.
New
(
"set failed"
)
}
func
(
c
*
openAIWSStateStoreTimeoutProbeCache
)
RefreshSessionTTL
(
context
.
Context
,
int64
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
c
*
openAIWSStateStoreTimeoutProbeCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
_
int64
,
_
string
)
error
{
if
deadline
,
ok
:=
ctx
.
Deadline
();
ok
{
c
.
deleteHasDeadline
=
true
c
.
delDeadlineDelta
=
time
.
Until
(
deadline
)
}
return
nil
}
func
TestOpenAIWSStateStore_RedisOpsUseShortTimeout
(
t
*
testing
.
T
)
{
probe
:=
&
openAIWSStateStoreTimeoutProbeCache
{}
store
:=
NewOpenAIWSStateStore
(
probe
)
ctx
:=
context
.
Background
()
groupID
:=
int64
(
5
)
err
:=
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_timeout_probe"
,
11
,
time
.
Minute
)
require
.
Error
(
t
,
err
)
accountID
,
getErr
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_timeout_probe"
)
require
.
NoError
(
t
,
getErr
)
require
.
Equal
(
t
,
int64
(
11
),
accountID
,
"本地缓存命中应优先返回已绑定账号"
)
require
.
NoError
(
t
,
store
.
DeleteResponseAccount
(
ctx
,
groupID
,
"resp_timeout_probe"
))
require
.
True
(
t
,
probe
.
setHasDeadline
,
"SetSessionAccountID 应携带独立超时上下文"
)
require
.
True
(
t
,
probe
.
deleteHasDeadline
,
"DeleteSessionAccountID 应携带独立超时上下文"
)
require
.
False
(
t
,
probe
.
getHasDeadline
,
"GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取"
)
require
.
Greater
(
t
,
probe
.
setDeadlineDelta
,
2
*
time
.
Second
)
require
.
LessOrEqual
(
t
,
probe
.
setDeadlineDelta
,
3
*
time
.
Second
)
require
.
Greater
(
t
,
probe
.
delDeadlineDelta
,
2
*
time
.
Second
)
require
.
LessOrEqual
(
t
,
probe
.
delDeadlineDelta
,
3
*
time
.
Second
)
probe2
:=
&
openAIWSStateStoreTimeoutProbeCache
{}
store2
:=
NewOpenAIWSStateStore
(
probe2
)
accountID2
,
err2
:=
store2
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_cache_only"
)
require
.
NoError
(
t
,
err2
)
require
.
Equal
(
t
,
int64
(
123
),
accountID2
)
require
.
True
(
t
,
probe2
.
getHasDeadline
,
"GetSessionAccountID 在缓存未命中时应携带独立超时上下文"
)
require
.
Greater
(
t
,
probe2
.
getDeadlineDelta
,
2
*
time
.
Second
)
require
.
LessOrEqual
(
t
,
probe2
.
getDeadlineDelta
,
3
*
time
.
Second
)
}
func
TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
withOpenAIWSStateStoreRedisTimeout
(
context
.
Background
())
defer
cancel
()
require
.
NotNil
(
t
,
ctx
)
_
,
ok
:=
ctx
.
Deadline
()
require
.
True
(
t
,
ok
,
"应附加短超时"
)
}
backend/internal/service/ops_retry.go
View file @
bb664d9b
...
...
@@ -13,7 +13,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
...
...
@@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
attemptCtx
:=
ctx
if
switches
>
0
{
attemptCtx
=
context
.
WithValue
(
attemptCtx
,
ctxkey
.
AccountSwitchCount
,
switches
)
attemptCtx
=
With
AccountSwitchCount
(
attemptCtx
,
switches
,
false
)
}
exec
:=
func
()
*
opsRetryExecution
{
defer
selection
.
ReleaseFunc
()
...
...
@@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.
}
c
.
Request
=
req
SetOpenAIClientTransport
(
c
,
OpenAIClientTransportHTTP
)
return
c
,
w
}
...
...
backend/internal/service/ops_retry_context_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
func
TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders
(
t
*
testing
.
T
)
{
errorLog
:=
&
OpsErrorLogDetail
{
OpsErrorLog
:
OpsErrorLog
{
RequestPath
:
"/openai/v1/responses"
,
},
UserAgent
:
"ops-retry-agent/1.0"
,
RequestHeaders
:
`{
"anthropic-beta":"beta-v1",
"ANTHROPIC-VERSION":"2023-06-01",
"authorization":"Bearer should-not-forward"
}`
,
}
c
,
w
:=
newOpsRetryContext
(
context
.
Background
(),
errorLog
)
require
.
NotNil
(
t
,
c
)
require
.
NotNil
(
t
,
w
)
require
.
NotNil
(
t
,
c
.
Request
)
require
.
Equal
(
t
,
"/openai/v1/responses"
,
c
.
Request
.
URL
.
Path
)
require
.
Equal
(
t
,
"application/json"
,
c
.
Request
.
Header
.
Get
(
"Content-Type"
))
require
.
Equal
(
t
,
"ops-retry-agent/1.0"
,
c
.
Request
.
Header
.
Get
(
"User-Agent"
))
require
.
Equal
(
t
,
"beta-v1"
,
c
.
Request
.
Header
.
Get
(
"anthropic-beta"
))
require
.
Equal
(
t
,
"2023-06-01"
,
c
.
Request
.
Header
.
Get
(
"anthropic-version"
))
require
.
Empty
(
t
,
c
.
Request
.
Header
.
Get
(
"authorization"
),
"未在白名单内的敏感头不应被重放"
)
require
.
Equal
(
t
,
OpenAIClientTransportHTTP
,
GetOpenAIClientTransport
(
c
))
}
func
TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport
(
t
*
testing
.
T
)
{
errorLog
:=
&
OpsErrorLogDetail
{
RequestHeaders
:
"{invalid-json"
,
}
c
,
_
:=
newOpsRetryContext
(
context
.
Background
(),
errorLog
)
require
.
NotNil
(
t
,
c
)
require
.
NotNil
(
t
,
c
.
Request
)
require
.
Equal
(
t
,
"/"
,
c
.
Request
.
URL
.
Path
)
require
.
Equal
(
t
,
OpenAIClientTransportHTTP
,
GetOpenAIClientTransport
(
c
))
}
backend/internal/service/ops_upstream_context.go
View file @
bb664d9b
...
...
@@ -27,6 +27,11 @@ const (
OpsUpstreamLatencyMsKey
=
"ops_upstream_latency_ms"
OpsResponseLatencyMsKey
=
"ops_response_latency_ms"
OpsTimeToFirstTokenMsKey
=
"ops_time_to_first_token_ms"
// OpenAI WS 关键观测字段
OpsOpenAIWSQueueWaitMsKey
=
"ops_openai_ws_queue_wait_ms"
OpsOpenAIWSConnPickMsKey
=
"ops_openai_ws_conn_pick_ms"
OpsOpenAIWSConnReusedKey
=
"ops_openai_ws_conn_reused"
OpsOpenAIWSConnIDKey
=
"ops_openai_ws_conn_id"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
...
...
backend/internal/service/ratelimit_service.go
View file @
bb664d9b
...
...
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// RateLimitService 处理限流和过载状态管理
...
...
@@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct {
totals
GeminiUsageTotals
}
type
geminiUsageTotalsBatchProvider
interface
{
GetGeminiUsageTotalsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
GeminiUsageTotals
,
error
)
}
const
geminiPrecheckCacheTTL
=
time
.
Minute
// NewRateLimitService 创建RateLimitService实例
...
...
@@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
if
upstreamMsg
!=
""
{
msg
=
"Access forbidden (403): "
+
upstreamMsg
}
logger
.
LegacyPrintf
(
"service.ratelimit"
,
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s"
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
strings
.
TrimSpace
(
headers
.
Get
(
"x-request-id"
)),
strings
.
TrimSpace
(
headers
.
Get
(
"cf-ray"
)),
upstreamMsg
,
truncateForLog
(
responseBody
,
1024
),
)
s
.
handleAuthError
(
ctx
,
account
,
msg
)
shouldDisable
=
true
case
429
:
...
...
@@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start
:=
geminiDailyWindowStart
(
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
,
nil
)
if
err
!=
nil
{
return
true
,
err
}
...
...
@@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
,
nil
)
if
err
!=
nil
{
return
true
,
err
}
...
...
@@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return
true
,
nil
}
// PreCheckUsageBatch performs quota precheck for multiple accounts in one request.
// Returned map value=false means the account should be skipped.
func
(
s
*
RateLimitService
)
PreCheckUsageBatch
(
ctx
context
.
Context
,
accounts
[]
*
Account
,
requestedModel
string
)
(
map
[
int64
]
bool
,
error
)
{
result
:=
make
(
map
[
int64
]
bool
,
len
(
accounts
))
for
_
,
account
:=
range
accounts
{
if
account
==
nil
{
continue
}
result
[
account
.
ID
]
=
true
}
if
len
(
accounts
)
==
0
||
requestedModel
==
""
{
return
result
,
nil
}
if
s
.
usageRepo
==
nil
||
s
.
geminiQuotaService
==
nil
{
return
result
,
nil
}
modelClass
:=
geminiModelClassFromName
(
requestedModel
)
now
:=
time
.
Now
()
dailyStart
:=
geminiDailyWindowStart
(
now
)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
type
quotaAccount
struct
{
account
*
Account
quota
GeminiQuota
}
quotaAccounts
:=
make
([]
quotaAccount
,
0
,
len
(
accounts
))
for
_
,
account
:=
range
accounts
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
continue
}
quota
,
ok
:=
s
.
geminiQuotaService
.
QuotaForAccount
(
ctx
,
account
)
if
!
ok
{
continue
}
quotaAccounts
=
append
(
quotaAccounts
,
quotaAccount
{
account
:
account
,
quota
:
quota
,
})
}
if
len
(
quotaAccounts
)
==
0
{
return
result
,
nil
}
// 1) Daily precheck (cached + batch DB fallback)
dailyTotalsByID
:=
make
(
map
[
int64
]
GeminiUsageTotals
,
len
(
quotaAccounts
))
dailyMissIDs
:=
make
([]
int64
,
0
,
len
(
quotaAccounts
))
for
_
,
item
:=
range
quotaAccounts
{
limit
:=
geminiDailyLimit
(
item
.
quota
,
modelClass
)
if
limit
<=
0
{
continue
}
accountID
:=
item
.
account
.
ID
if
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
accountID
,
dailyStart
,
now
);
ok
{
dailyTotalsByID
[
accountID
]
=
totals
continue
}
dailyMissIDs
=
append
(
dailyMissIDs
,
accountID
)
}
if
len
(
dailyMissIDs
)
>
0
{
totalsBatch
,
err
:=
s
.
getGeminiUsageTotalsBatch
(
ctx
,
dailyMissIDs
,
dailyStart
,
now
)
if
err
!=
nil
{
return
result
,
err
}
for
_
,
accountID
:=
range
dailyMissIDs
{
totals
:=
totalsBatch
[
accountID
]
dailyTotalsByID
[
accountID
]
=
totals
s
.
setGeminiUsageTotals
(
accountID
,
dailyStart
,
now
,
totals
)
}
}
for
_
,
item
:=
range
quotaAccounts
{
limit
:=
geminiDailyLimit
(
item
.
quota
,
modelClass
)
if
limit
<=
0
{
continue
}
accountID
:=
item
.
account
.
ID
used
:=
geminiUsedRequests
(
item
.
quota
,
modelClass
,
dailyTotalsByID
[
accountID
],
true
)
if
used
>=
limit
{
resetAt
:=
geminiDailyResetTime
(
now
)
slog
.
Info
(
"gemini_precheck_daily_quota_reached_batch"
,
"account_id"
,
accountID
,
"used"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
result
[
accountID
]
=
false
}
}
// 2) Minute precheck (batch DB)
minuteIDs
:=
make
([]
int64
,
0
,
len
(
quotaAccounts
))
for
_
,
item
:=
range
quotaAccounts
{
accountID
:=
item
.
account
.
ID
if
!
result
[
accountID
]
{
continue
}
if
geminiMinuteLimit
(
item
.
quota
,
modelClass
)
<=
0
{
continue
}
minuteIDs
=
append
(
minuteIDs
,
accountID
)
}
if
len
(
minuteIDs
)
==
0
{
return
result
,
nil
}
minuteTotalsByID
,
err
:=
s
.
getGeminiUsageTotalsBatch
(
ctx
,
minuteIDs
,
minuteStart
,
now
)
if
err
!=
nil
{
return
result
,
err
}
for
_
,
item
:=
range
quotaAccounts
{
accountID
:=
item
.
account
.
ID
if
!
result
[
accountID
]
{
continue
}
limit
:=
geminiMinuteLimit
(
item
.
quota
,
modelClass
)
if
limit
<=
0
{
continue
}
used
:=
geminiUsedRequests
(
item
.
quota
,
modelClass
,
minuteTotalsByID
[
accountID
],
false
)
if
used
>=
limit
{
resetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
slog
.
Info
(
"gemini_precheck_minute_quota_reached_batch"
,
"account_id"
,
accountID
,
"used"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
result
[
accountID
]
=
false
}
}
return
result
,
nil
}
func
(
s
*
RateLimitService
)
getGeminiUsageTotalsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
start
,
end
time
.
Time
)
(
map
[
int64
]
GeminiUsageTotals
,
error
)
{
result
:=
make
(
map
[
int64
]
GeminiUsageTotals
,
len
(
accountIDs
))
if
len
(
accountIDs
)
==
0
{
return
result
,
nil
}
ids
:=
make
([]
int64
,
0
,
len
(
accountIDs
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
if
accountID
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
accountID
];
ok
{
continue
}
seen
[
accountID
]
=
struct
{}{}
ids
=
append
(
ids
,
accountID
)
}
if
len
(
ids
)
==
0
{
return
result
,
nil
}
if
batchReader
,
ok
:=
s
.
usageRepo
.
(
geminiUsageTotalsBatchProvider
);
ok
{
stats
,
err
:=
batchReader
.
GetGeminiUsageTotalsBatch
(
ctx
,
ids
,
start
,
end
)
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
accountID
:=
range
ids
{
result
[
accountID
]
=
stats
[
accountID
]
}
return
result
,
nil
}
for
_
,
accountID
:=
range
ids
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
end
,
0
,
0
,
accountID
,
0
,
nil
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
err
}
result
[
accountID
]
=
geminiAggregateUsage
(
stats
)
}
return
result
,
nil
}
func
geminiDailyLimit
(
quota
GeminiQuota
,
modelClass
geminiModelClass
)
int64
{
if
quota
.
SharedRPD
>
0
{
return
quota
.
SharedRPD
}
switch
modelClass
{
case
geminiModelFlash
:
return
quota
.
FlashRPD
default
:
return
quota
.
ProRPD
}
}
func
geminiMinuteLimit
(
quota
GeminiQuota
,
modelClass
geminiModelClass
)
int64
{
if
quota
.
SharedRPM
>
0
{
return
quota
.
SharedRPM
}
switch
modelClass
{
case
geminiModelFlash
:
return
quota
.
FlashRPM
default
:
return
quota
.
ProRPM
}
}
func
geminiUsedRequests
(
quota
GeminiQuota
,
modelClass
geminiModelClass
,
totals
GeminiUsageTotals
,
daily
bool
)
int64
{
if
daily
{
if
quota
.
SharedRPD
>
0
{
return
totals
.
ProRequests
+
totals
.
FlashRequests
}
}
else
{
if
quota
.
SharedRPM
>
0
{
return
totals
.
ProRequests
+
totals
.
FlashRequests
}
}
switch
modelClass
{
case
geminiModelFlash
:
return
totals
.
FlashRequests
default
:
return
totals
.
ProRequests
}
}
func
(
s
*
RateLimitService
)
getGeminiUsageTotals
(
accountID
int64
,
windowStart
,
now
time
.
Time
)
(
GeminiUsageTotals
,
bool
)
{
s
.
usageCacheMu
.
RLock
()
defer
s
.
usageCacheMu
.
RUnlock
()
...
...
backend/internal/service/request_metadata.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
type
requestMetadataContextKey
struct
{}
var
requestMetadataKey
=
requestMetadataContextKey
{}
type
RequestMetadata
struct
{
IsMaxTokensOneHaikuRequest
*
bool
ThinkingEnabled
*
bool
PrefetchedStickyAccountID
*
int64
PrefetchedStickyGroupID
*
int64
SingleAccountRetry
*
bool
AccountSwitchCount
*
int
}
var
(
requestMetadataFallbackIsMaxTokensOneHaikuTotal
atomic
.
Int64
requestMetadataFallbackThinkingEnabledTotal
atomic
.
Int64
requestMetadataFallbackPrefetchedStickyAccount
atomic
.
Int64
requestMetadataFallbackPrefetchedStickyGroup
atomic
.
Int64
requestMetadataFallbackSingleAccountRetryTotal
atomic
.
Int64
requestMetadataFallbackAccountSwitchCountTotal
atomic
.
Int64
)
func
RequestMetadataFallbackStats
()
(
isMaxTokensOneHaiku
,
thinkingEnabled
,
prefetchedStickyAccount
,
prefetchedStickyGroup
,
singleAccountRetry
,
accountSwitchCount
int64
)
{
return
requestMetadataFallbackIsMaxTokensOneHaikuTotal
.
Load
(),
requestMetadataFallbackThinkingEnabledTotal
.
Load
(),
requestMetadataFallbackPrefetchedStickyAccount
.
Load
(),
requestMetadataFallbackPrefetchedStickyGroup
.
Load
(),
requestMetadataFallbackSingleAccountRetryTotal
.
Load
(),
requestMetadataFallbackAccountSwitchCountTotal
.
Load
()
}
func
metadataFromContext
(
ctx
context
.
Context
)
*
RequestMetadata
{
if
ctx
==
nil
{
return
nil
}
md
,
_
:=
ctx
.
Value
(
requestMetadataKey
)
.
(
*
RequestMetadata
)
return
md
}
func
updateRequestMetadata
(
ctx
context
.
Context
,
bridgeOldKeys
bool
,
update
func
(
md
*
RequestMetadata
),
legacyBridge
func
(
ctx
context
.
Context
)
context
.
Context
,
)
context
.
Context
{
if
ctx
==
nil
{
return
nil
}
current
:=
metadataFromContext
(
ctx
)
next
:=
&
RequestMetadata
{}
if
current
!=
nil
{
*
next
=
*
current
}
update
(
next
)
ctx
=
context
.
WithValue
(
ctx
,
requestMetadataKey
,
next
)
if
bridgeOldKeys
&&
legacyBridge
!=
nil
{
ctx
=
legacyBridge
(
ctx
)
}
return
ctx
}
func
WithIsMaxTokensOneHaikuRequest
(
ctx
context
.
Context
,
value
bool
,
bridgeOldKeys
bool
)
context
.
Context
{
return
updateRequestMetadata
(
ctx
,
bridgeOldKeys
,
func
(
md
*
RequestMetadata
)
{
v
:=
value
md
.
IsMaxTokensOneHaikuRequest
=
&
v
},
func
(
base
context
.
Context
)
context
.
Context
{
return
context
.
WithValue
(
base
,
ctxkey
.
IsMaxTokensOneHaikuRequest
,
value
)
})
}
func
WithThinkingEnabled
(
ctx
context
.
Context
,
value
bool
,
bridgeOldKeys
bool
)
context
.
Context
{
return
updateRequestMetadata
(
ctx
,
bridgeOldKeys
,
func
(
md
*
RequestMetadata
)
{
v
:=
value
md
.
ThinkingEnabled
=
&
v
},
func
(
base
context
.
Context
)
context
.
Context
{
return
context
.
WithValue
(
base
,
ctxkey
.
ThinkingEnabled
,
value
)
})
}
func
WithPrefetchedStickySession
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
bridgeOldKeys
bool
)
context
.
Context
{
return
updateRequestMetadata
(
ctx
,
bridgeOldKeys
,
func
(
md
*
RequestMetadata
)
{
account
:=
accountID
group
:=
groupID
md
.
PrefetchedStickyAccountID
=
&
account
md
.
PrefetchedStickyGroupID
=
&
group
},
func
(
base
context
.
Context
)
context
.
Context
{
bridged
:=
context
.
WithValue
(
base
,
ctxkey
.
PrefetchedStickyAccountID
,
accountID
)
return
context
.
WithValue
(
bridged
,
ctxkey
.
PrefetchedStickyGroupID
,
groupID
)
})
}
func
WithSingleAccountRetry
(
ctx
context
.
Context
,
value
bool
,
bridgeOldKeys
bool
)
context
.
Context
{
return
updateRequestMetadata
(
ctx
,
bridgeOldKeys
,
func
(
md
*
RequestMetadata
)
{
v
:=
value
md
.
SingleAccountRetry
=
&
v
},
func
(
base
context
.
Context
)
context
.
Context
{
return
context
.
WithValue
(
base
,
ctxkey
.
SingleAccountRetry
,
value
)
})
}
func
WithAccountSwitchCount
(
ctx
context
.
Context
,
value
int
,
bridgeOldKeys
bool
)
context
.
Context
{
return
updateRequestMetadata
(
ctx
,
bridgeOldKeys
,
func
(
md
*
RequestMetadata
)
{
v
:=
value
md
.
AccountSwitchCount
=
&
v
},
func
(
base
context
.
Context
)
context
.
Context
{
return
context
.
WithValue
(
base
,
ctxkey
.
AccountSwitchCount
,
value
)
})
}
func
IsMaxTokensOneHaikuRequestFromContext
(
ctx
context
.
Context
)
(
bool
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
IsMaxTokensOneHaikuRequest
!=
nil
{
return
*
md
.
IsMaxTokensOneHaikuRequest
,
true
}
if
ctx
==
nil
{
return
false
,
false
}
if
value
,
ok
:=
ctx
.
Value
(
ctxkey
.
IsMaxTokensOneHaikuRequest
)
.
(
bool
);
ok
{
requestMetadataFallbackIsMaxTokensOneHaikuTotal
.
Add
(
1
)
return
value
,
true
}
return
false
,
false
}
func
ThinkingEnabledFromContext
(
ctx
context
.
Context
)
(
bool
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
ThinkingEnabled
!=
nil
{
return
*
md
.
ThinkingEnabled
,
true
}
if
ctx
==
nil
{
return
false
,
false
}
if
value
,
ok
:=
ctx
.
Value
(
ctxkey
.
ThinkingEnabled
)
.
(
bool
);
ok
{
requestMetadataFallbackThinkingEnabledTotal
.
Add
(
1
)
return
value
,
true
}
return
false
,
false
}
func
PrefetchedStickyGroupIDFromContext
(
ctx
context
.
Context
)
(
int64
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
PrefetchedStickyGroupID
!=
nil
{
return
*
md
.
PrefetchedStickyGroupID
,
true
}
if
ctx
==
nil
{
return
0
,
false
}
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyGroupID
)
switch
t
:=
v
.
(
type
)
{
case
int64
:
requestMetadataFallbackPrefetchedStickyGroup
.
Add
(
1
)
return
t
,
true
case
int
:
requestMetadataFallbackPrefetchedStickyGroup
.
Add
(
1
)
return
int64
(
t
),
true
}
return
0
,
false
}
func
PrefetchedStickyAccountIDFromContext
(
ctx
context
.
Context
)
(
int64
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
PrefetchedStickyAccountID
!=
nil
{
return
*
md
.
PrefetchedStickyAccountID
,
true
}
if
ctx
==
nil
{
return
0
,
false
}
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
)
switch
t
:=
v
.
(
type
)
{
case
int64
:
requestMetadataFallbackPrefetchedStickyAccount
.
Add
(
1
)
return
t
,
true
case
int
:
requestMetadataFallbackPrefetchedStickyAccount
.
Add
(
1
)
return
int64
(
t
),
true
}
return
0
,
false
}
func
SingleAccountRetryFromContext
(
ctx
context
.
Context
)
(
bool
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
SingleAccountRetry
!=
nil
{
return
*
md
.
SingleAccountRetry
,
true
}
if
ctx
==
nil
{
return
false
,
false
}
if
value
,
ok
:=
ctx
.
Value
(
ctxkey
.
SingleAccountRetry
)
.
(
bool
);
ok
{
requestMetadataFallbackSingleAccountRetryTotal
.
Add
(
1
)
return
value
,
true
}
return
false
,
false
}
func
AccountSwitchCountFromContext
(
ctx
context
.
Context
)
(
int
,
bool
)
{
if
md
:=
metadataFromContext
(
ctx
);
md
!=
nil
&&
md
.
AccountSwitchCount
!=
nil
{
return
*
md
.
AccountSwitchCount
,
true
}
if
ctx
==
nil
{
return
0
,
false
}
v
:=
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
)
switch
t
:=
v
.
(
type
)
{
case
int
:
requestMetadataFallbackAccountSwitchCountTotal
.
Add
(
1
)
return
t
,
true
case
int64
:
requestMetadataFallbackAccountSwitchCountTotal
.
Add
(
1
)
return
int
(
t
),
true
}
return
0
,
false
}
backend/internal/service/request_metadata_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func
TestRequestMetadataWriteAndRead_NoBridge
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
=
WithIsMaxTokensOneHaikuRequest
(
ctx
,
true
,
false
)
ctx
=
WithThinkingEnabled
(
ctx
,
true
,
false
)
ctx
=
WithPrefetchedStickySession
(
ctx
,
123
,
456
,
false
)
ctx
=
WithSingleAccountRetry
(
ctx
,
true
,
false
)
ctx
=
WithAccountSwitchCount
(
ctx
,
2
,
false
)
isHaiku
,
ok
:=
IsMaxTokensOneHaikuRequestFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
isHaiku
)
thinking
,
ok
:=
ThinkingEnabledFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
thinking
)
accountID
,
ok
:=
PrefetchedStickyAccountIDFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
123
),
accountID
)
groupID
,
ok
:=
PrefetchedStickyGroupIDFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
456
),
groupID
)
singleRetry
,
ok
:=
SingleAccountRetryFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
singleRetry
)
switchCount
,
ok
:=
AccountSwitchCountFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
2
,
switchCount
)
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
IsMaxTokensOneHaikuRequest
))
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
ThinkingEnabled
))
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
))
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
PrefetchedStickyGroupID
))
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
SingleAccountRetry
))
require
.
Nil
(
t
,
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
))
}
func
TestRequestMetadataWrite_BridgeLegacyKeys
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
=
WithIsMaxTokensOneHaikuRequest
(
ctx
,
true
,
true
)
ctx
=
WithThinkingEnabled
(
ctx
,
true
,
true
)
ctx
=
WithPrefetchedStickySession
(
ctx
,
123
,
456
,
true
)
ctx
=
WithSingleAccountRetry
(
ctx
,
true
,
true
)
ctx
=
WithAccountSwitchCount
(
ctx
,
2
,
true
)
require
.
Equal
(
t
,
true
,
ctx
.
Value
(
ctxkey
.
IsMaxTokensOneHaikuRequest
))
require
.
Equal
(
t
,
true
,
ctx
.
Value
(
ctxkey
.
ThinkingEnabled
))
require
.
Equal
(
t
,
int64
(
123
),
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
))
require
.
Equal
(
t
,
int64
(
456
),
ctx
.
Value
(
ctxkey
.
PrefetchedStickyGroupID
))
require
.
Equal
(
t
,
true
,
ctx
.
Value
(
ctxkey
.
SingleAccountRetry
))
require
.
Equal
(
t
,
2
,
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
))
}
func
TestRequestMetadataRead_LegacyFallbackAndStats
(
t
*
testing
.
T
)
{
beforeHaiku
,
beforeThinking
,
beforeAccount
,
beforeGroup
,
beforeSingleRetry
,
beforeSwitchCount
:=
RequestMetadataFallbackStats
()
ctx
:=
context
.
Background
()
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
IsMaxTokensOneHaikuRequest
,
true
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ThinkingEnabled
,
true
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
321
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
654
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
SingleAccountRetry
,
true
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
AccountSwitchCount
,
int64
(
3
))
isHaiku
,
ok
:=
IsMaxTokensOneHaikuRequestFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
isHaiku
)
thinking
,
ok
:=
ThinkingEnabledFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
thinking
)
accountID
,
ok
:=
PrefetchedStickyAccountIDFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
321
),
accountID
)
groupID
,
ok
:=
PrefetchedStickyGroupIDFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
654
),
groupID
)
singleRetry
,
ok
:=
SingleAccountRetryFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
singleRetry
)
switchCount
,
ok
:=
AccountSwitchCountFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
3
,
switchCount
)
afterHaiku
,
afterThinking
,
afterAccount
,
afterGroup
,
afterSingleRetry
,
afterSwitchCount
:=
RequestMetadataFallbackStats
()
require
.
Equal
(
t
,
beforeHaiku
+
1
,
afterHaiku
)
require
.
Equal
(
t
,
beforeThinking
+
1
,
afterThinking
)
require
.
Equal
(
t
,
beforeAccount
+
1
,
afterAccount
)
require
.
Equal
(
t
,
beforeGroup
+
1
,
afterGroup
)
require
.
Equal
(
t
,
beforeSingleRetry
+
1
,
afterSingleRetry
)
require
.
Equal
(
t
,
beforeSwitchCount
+
1
,
afterSwitchCount
)
}
func
TestRequestMetadataRead_PreferMetadataOverLegacy
(
t
*
testing
.
T
)
{
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
ThinkingEnabled
,
false
)
ctx
=
WithThinkingEnabled
(
ctx
,
true
,
false
)
thinking
,
ok
:=
ThinkingEnabledFromContext
(
ctx
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
thinking
)
require
.
Equal
(
t
,
false
,
ctx
.
Value
(
ctxkey
.
ThinkingEnabled
))
}
backend/internal/service/response_header_filter.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
)
func
compileResponseHeaderFilter
(
cfg
*
config
.
Config
)
*
responseheaders
.
CompiledHeaderFilter
{
if
cfg
==
nil
{
return
nil
}
return
responseheaders
.
CompileHeaderFilter
(
cfg
.
Security
.
ResponseHeaders
)
}
backend/internal/service/scheduler_snapshot_service.go
View file @
bb664d9b
...
...
@@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
if
payload
==
nil
{
return
nil
}
ids
:=
parseInt64Slice
(
payload
[
"account_ids"
])
if
s
.
accountRepo
==
nil
{
return
nil
}
rawIDs
:=
parseInt64Slice
(
payload
[
"account_ids"
])
if
len
(
rawIDs
)
==
0
{
return
nil
}
ids
:=
make
([]
int64
,
0
,
len
(
rawIDs
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
rawIDs
))
for
_
,
id
:=
range
rawIDs
{
if
id
<=
0
{
continue
}
if
_
,
exists
:=
seen
[
id
];
exists
{
continue
}
seen
[
id
]
=
struct
{}{}
ids
=
append
(
ids
,
id
)
}
if
len
(
ids
)
==
0
{
return
nil
}
preloadGroupIDs
:=
parseInt64Slice
(
payload
[
"group_ids"
])
accounts
,
err
:=
s
.
accountRepo
.
GetByIDs
(
ctx
,
ids
)
if
err
!=
nil
{
return
err
}
found
:=
make
(
map
[
int64
]
struct
{},
len
(
accounts
))
rebuildGroupSet
:=
make
(
map
[
int64
]
struct
{},
len
(
preloadGroupIDs
))
for
_
,
gid
:=
range
preloadGroupIDs
{
if
gid
>
0
{
rebuildGroupSet
[
gid
]
=
struct
{}{}
}
}
for
_
,
account
:=
range
accounts
{
if
account
==
nil
||
account
.
ID
<=
0
{
continue
}
found
[
account
.
ID
]
=
struct
{}{}
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
return
err
}
}
for
_
,
gid
:=
range
account
.
GroupIDs
{
if
gid
>
0
{
rebuildGroupSet
[
gid
]
=
struct
{}{}
}
}
}
if
s
.
cache
!=
nil
{
for
_
,
id
:=
range
ids
{
if
err
:=
s
.
handleAccountEvent
(
ctx
,
&
id
,
payload
);
err
!=
nil
{
if
_
,
ok
:=
found
[
id
];
ok
{
continue
}
if
err
:=
s
.
cache
.
DeleteAccount
(
ctx
,
id
);
err
!=
nil
{
return
err
}
}
return
nil
}
rebuildGroupIDs
:=
make
([]
int64
,
0
,
len
(
rebuildGroupSet
))
for
gid
:=
range
rebuildGroupSet
{
rebuildGroupIDs
=
append
(
rebuildGroupIDs
,
gid
)
}
return
s
.
rebuildByGroupIDs
(
ctx
,
rebuildGroupIDs
,
"account_bulk_change"
)
}
func
(
s
*
SchedulerSnapshotService
)
handleAccountEvent
(
ctx
context
.
Context
,
accountID
*
int64
,
payload
map
[
string
]
any
)
error
{
...
...
backend/internal/service/setting_service.go
View file @
bb664d9b
...
...
@@ -9,6 +9,7 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...
...
@@ -17,6 +18,8 @@ import (
var
(
ErrRegistrationDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrSettingNotFound
=
infraerrors
.
NotFound
(
"SETTING_NOT_FOUND"
,
"setting not found"
)
ErrSoraS3ProfileNotFound
=
infraerrors
.
NotFound
(
"SORA_S3_PROFILE_NOT_FOUND"
,
"sora s3 profile not found"
)
ErrSoraS3ProfileExists
=
infraerrors
.
Conflict
(
"SORA_S3_PROFILE_EXISTS"
,
"sora s3 profile already exists"
)
)
type
SettingRepository
interface
{
...
...
@@ -34,6 +37,7 @@ type SettingService struct {
settingRepo
SettingRepository
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
onS3Update
func
()
// Callback when Sora S3 settings are updated
version
string
// Application version
}
...
...
@@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeySoraClientEnabled
,
SettingKeyLinuxDoConnectEnabled
,
}
...
...
@@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
},
nil
}
...
...
@@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s
.
onUpdate
=
callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func
(
s
*
SettingService
)
SetOnS3UpdateCallback
(
callback
func
())
{
s
.
onS3Update
=
callback
}
// SetVersion sets the application version for injection into public settings
func
(
s
*
SettingService
)
SetVersion
(
version
string
)
{
s
.
version
=
version
...
...
@@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version,omitempty"`
}{
...
...
@@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
SoraClientEnabled
:
settings
.
SoraClientEnabled
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
s
.
version
,
},
nil
...
...
@@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
updates
[
SettingKeySoraClientEnabled
]
=
strconv
.
FormatBool
(
settings
.
SoraClientEnabled
)
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
...
...
@@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeySoraClientEnabled
:
"false"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySMTPPort
:
"587"
,
...
...
@@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
}
// 解析整数类型
...
...
@@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyStreamTimeoutSettings
,
string
(
data
))
}
type
soraS3ProfilesStore
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
soraS3ProfileStoreItem
`json:"items"`
}
type
soraS3ProfileStoreItem
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func
(
s
*
SettingService
)
GetSoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
profiles
,
err
:=
s
.
ListSoraS3Profiles
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
activeProfile
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
activeProfile
==
nil
{
return
&
SoraS3Settings
{},
nil
}
return
&
SoraS3Settings
{
Enabled
:
activeProfile
.
Enabled
,
Endpoint
:
activeProfile
.
Endpoint
,
Region
:
activeProfile
.
Region
,
Bucket
:
activeProfile
.
Bucket
,
AccessKeyID
:
activeProfile
.
AccessKeyID
,
SecretAccessKey
:
activeProfile
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
activeProfile
.
SecretAccessKeyConfigured
,
Prefix
:
activeProfile
.
Prefix
,
ForcePathStyle
:
activeProfile
.
ForcePathStyle
,
CDNURL
:
activeProfile
.
CDNURL
,
DefaultStorageQuotaBytes
:
activeProfile
.
DefaultStorageQuotaBytes
,
},
nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func
(
s
*
SettingService
)
SetSoraS3Settings
(
ctx
context
.
Context
,
settings
*
SoraS3Settings
)
error
{
if
settings
==
nil
{
return
fmt
.
Errorf
(
"settings cannot be nil"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
activeIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
store
.
ActiveProfileID
)
if
activeIndex
<
0
{
activeID
:=
"default"
if
hasSoraS3ProfileID
(
store
.
Items
,
activeID
)
{
activeID
=
fmt
.
Sprintf
(
"default-%d"
,
time
.
Now
()
.
Unix
())
}
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
activeID
,
Name
:
"Default"
,
UpdatedAt
:
now
,
})
store
.
ActiveProfileID
=
activeID
activeIndex
=
len
(
store
.
Items
)
-
1
}
active
:=
store
.
Items
[
activeIndex
]
active
.
Enabled
=
settings
.
Enabled
active
.
Endpoint
=
strings
.
TrimSpace
(
settings
.
Endpoint
)
active
.
Region
=
strings
.
TrimSpace
(
settings
.
Region
)
active
.
Bucket
=
strings
.
TrimSpace
(
settings
.
Bucket
)
active
.
AccessKeyID
=
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
active
.
Prefix
=
strings
.
TrimSpace
(
settings
.
Prefix
)
active
.
ForcePathStyle
=
settings
.
ForcePathStyle
active
.
CDNURL
=
strings
.
TrimSpace
(
settings
.
CDNURL
)
active
.
DefaultStorageQuotaBytes
=
maxInt64
(
settings
.
DefaultStorageQuotaBytes
,
0
)
if
settings
.
SecretAccessKey
!=
""
{
active
.
SecretAccessKey
=
settings
.
SecretAccessKey
}
active
.
UpdatedAt
=
now
store
.
Items
[
activeIndex
]
=
active
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func
(
s
*
SettingService
)
ListSoraS3Profiles
(
ctx
context
.
Context
)
(
*
SoraS3ProfileList
,
error
)
{
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
convertSoraS3ProfilesStore
(
store
),
nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func
(
s
*
SettingService
)
CreateSoraS3Profile
(
ctx
context
.
Context
,
profile
*
SoraS3Profile
,
setActive
bool
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
profileID
:=
strings
.
TrimSpace
(
profile
.
ProfileID
)
if
profileID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
if
hasSoraS3ProfileID
(
store
.
Items
,
profileID
)
{
return
nil
,
ErrSoraS3ProfileExists
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
profileID
,
Name
:
name
,
Enabled
:
profile
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
profile
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
profile
.
Region
),
Bucket
:
strings
.
TrimSpace
(
profile
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
profile
.
AccessKeyID
),
SecretAccessKey
:
profile
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
profile
.
Prefix
),
ForcePathStyle
:
profile
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
profile
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
})
if
setActive
||
store
.
ActiveProfileID
==
""
{
store
.
ActiveProfileID
=
profileID
}
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
created
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
profileID
)
if
created
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
created
,
nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func
(
s
*
SettingService
)
UpdateSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
,
profile
*
SoraS3Profile
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
target
:=
store
.
Items
[
targetIndex
]
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
target
.
Name
=
name
target
.
Enabled
=
profile
.
Enabled
target
.
Endpoint
=
strings
.
TrimSpace
(
profile
.
Endpoint
)
target
.
Region
=
strings
.
TrimSpace
(
profile
.
Region
)
target
.
Bucket
=
strings
.
TrimSpace
(
profile
.
Bucket
)
target
.
AccessKeyID
=
strings
.
TrimSpace
(
profile
.
AccessKeyID
)
target
.
Prefix
=
strings
.
TrimSpace
(
profile
.
Prefix
)
target
.
ForcePathStyle
=
profile
.
ForcePathStyle
target
.
CDNURL
=
strings
.
TrimSpace
(
profile
.
CDNURL
)
target
.
DefaultStorageQuotaBytes
=
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
)
if
profile
.
SecretAccessKey
!=
""
{
target
.
SecretAccessKey
=
profile
.
SecretAccessKey
}
target
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
[
targetIndex
]
=
target
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
updated
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
targetID
)
if
updated
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
updated
,
nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func
(
s
*
SettingService
)
DeleteSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
ErrSoraS3ProfileNotFound
}
store
.
Items
=
append
(
store
.
Items
[
:
targetIndex
],
store
.
Items
[
targetIndex
+
1
:
]
...
)
if
store
.
ActiveProfileID
==
targetID
{
store
.
ActiveProfileID
=
""
if
len
(
store
.
Items
)
>
0
{
store
.
ActiveProfileID
=
store
.
Items
[
0
]
.
ProfileID
}
}
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func
(
s
*
SettingService
)
SetActiveSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
*
SoraS3Profile
,
error
)
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
store
.
ActiveProfileID
=
targetID
store
.
Items
[
targetIndex
]
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
active
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
active
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
active
,
nil
}
func
(
s
*
SettingService
)
loadSoraS3ProfilesStore
(
ctx
context
.
Context
)
(
*
soraS3ProfilesStore
,
error
)
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySoraS3Profiles
)
if
err
==
nil
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
&
soraS3ProfilesStore
{},
nil
}
var
store
soraS3ProfilesStore
if
unmarshalErr
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
store
);
unmarshalErr
!=
nil
{
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal sora s3 profiles: %w"
,
unmarshalErr
)
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
normalized
:=
normalizeSoraS3ProfilesStore
(
store
)
return
&
normalized
,
nil
}
if
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"get sora s3 profiles: %w"
,
err
)
}
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
legacyErr
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
func
(
s
*
SettingService
)
persistSoraS3ProfilesStore
(
ctx
context
.
Context
,
store
*
soraS3ProfilesStore
)
error
{
if
store
==
nil
{
return
fmt
.
Errorf
(
"sora s3 profiles store cannot be nil"
)
}
normalized
:=
normalizeSoraS3ProfilesStore
(
*
store
)
data
,
err
:=
json
.
Marshal
(
normalized
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal sora s3 profiles: %w"
,
err
)
}
updates
:=
map
[
string
]
string
{
SettingKeySoraS3Profiles
:
string
(
data
),
}
active
:=
pickActiveSoraS3ProfileFromStore
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
if
active
==
nil
{
updates
[
SettingKeySoraS3Enabled
]
=
"false"
updates
[
SettingKeySoraS3Endpoint
]
=
""
updates
[
SettingKeySoraS3Region
]
=
""
updates
[
SettingKeySoraS3Bucket
]
=
""
updates
[
SettingKeySoraS3AccessKeyID
]
=
""
updates
[
SettingKeySoraS3Prefix
]
=
""
updates
[
SettingKeySoraS3ForcePathStyle
]
=
"false"
updates
[
SettingKeySoraS3CDNURL
]
=
""
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
"0"
updates
[
SettingKeySoraS3SecretAccessKey
]
=
""
}
else
{
updates
[
SettingKeySoraS3Enabled
]
=
strconv
.
FormatBool
(
active
.
Enabled
)
updates
[
SettingKeySoraS3Endpoint
]
=
strings
.
TrimSpace
(
active
.
Endpoint
)
updates
[
SettingKeySoraS3Region
]
=
strings
.
TrimSpace
(
active
.
Region
)
updates
[
SettingKeySoraS3Bucket
]
=
strings
.
TrimSpace
(
active
.
Bucket
)
updates
[
SettingKeySoraS3AccessKeyID
]
=
strings
.
TrimSpace
(
active
.
AccessKeyID
)
updates
[
SettingKeySoraS3Prefix
]
=
strings
.
TrimSpace
(
active
.
Prefix
)
updates
[
SettingKeySoraS3ForcePathStyle
]
=
strconv
.
FormatBool
(
active
.
ForcePathStyle
)
updates
[
SettingKeySoraS3CDNURL
]
=
strings
.
TrimSpace
(
active
.
CDNURL
)
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
strconv
.
FormatInt
(
maxInt64
(
active
.
DefaultStorageQuotaBytes
,
0
),
10
)
updates
[
SettingKeySoraS3SecretAccessKey
]
=
active
.
SecretAccessKey
}
if
err
:=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
);
err
!=
nil
{
return
err
}
if
s
.
onUpdate
!=
nil
{
s
.
onUpdate
()
}
if
s
.
onS3Update
!=
nil
{
s
.
onS3Update
()
}
return
nil
}
func
(
s
*
SettingService
)
getLegacySoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
keys
:=
[]
string
{
SettingKeySoraS3Enabled
,
SettingKeySoraS3Endpoint
,
SettingKeySoraS3Region
,
SettingKeySoraS3Bucket
,
SettingKeySoraS3AccessKeyID
,
SettingKeySoraS3SecretAccessKey
,
SettingKeySoraS3Prefix
,
SettingKeySoraS3ForcePathStyle
,
SettingKeySoraS3CDNURL
,
SettingKeySoraDefaultStorageQuotaBytes
,
}
values
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get legacy sora s3 settings: %w"
,
err
)
}
result
:=
&
SoraS3Settings
{
Enabled
:
values
[
SettingKeySoraS3Enabled
]
==
"true"
,
Endpoint
:
values
[
SettingKeySoraS3Endpoint
],
Region
:
values
[
SettingKeySoraS3Region
],
Bucket
:
values
[
SettingKeySoraS3Bucket
],
AccessKeyID
:
values
[
SettingKeySoraS3AccessKeyID
],
SecretAccessKey
:
values
[
SettingKeySoraS3SecretAccessKey
],
SecretAccessKeyConfigured
:
values
[
SettingKeySoraS3SecretAccessKey
]
!=
""
,
Prefix
:
values
[
SettingKeySoraS3Prefix
],
ForcePathStyle
:
values
[
SettingKeySoraS3ForcePathStyle
]
==
"true"
,
CDNURL
:
values
[
SettingKeySoraS3CDNURL
],
}
if
v
,
parseErr
:=
strconv
.
ParseInt
(
values
[
SettingKeySoraDefaultStorageQuotaBytes
],
10
,
64
);
parseErr
==
nil
{
result
.
DefaultStorageQuotaBytes
=
v
}
return
result
,
nil
}
func
normalizeSoraS3ProfilesStore
(
store
soraS3ProfilesStore
)
soraS3ProfilesStore
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
store
.
Items
))
normalized
:=
soraS3ProfilesStore
{
ActiveProfileID
:
strings
.
TrimSpace
(
store
.
ActiveProfileID
),
Items
:
make
([]
soraS3ProfileStoreItem
,
0
,
len
(
store
.
Items
)),
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
item
.
ProfileID
=
strings
.
TrimSpace
(
item
.
ProfileID
)
if
item
.
ProfileID
==
""
{
item
.
ProfileID
=
fmt
.
Sprintf
(
"profile-%d"
,
idx
+
1
)
}
if
_
,
exists
:=
seen
[
item
.
ProfileID
];
exists
{
continue
}
seen
[
item
.
ProfileID
]
=
struct
{}{}
item
.
Name
=
strings
.
TrimSpace
(
item
.
Name
)
if
item
.
Name
==
""
{
item
.
Name
=
item
.
ProfileID
}
item
.
Endpoint
=
strings
.
TrimSpace
(
item
.
Endpoint
)
item
.
Region
=
strings
.
TrimSpace
(
item
.
Region
)
item
.
Bucket
=
strings
.
TrimSpace
(
item
.
Bucket
)
item
.
AccessKeyID
=
strings
.
TrimSpace
(
item
.
AccessKeyID
)
item
.
Prefix
=
strings
.
TrimSpace
(
item
.
Prefix
)
item
.
CDNURL
=
strings
.
TrimSpace
(
item
.
CDNURL
)
item
.
DefaultStorageQuotaBytes
=
maxInt64
(
item
.
DefaultStorageQuotaBytes
,
0
)
item
.
UpdatedAt
=
strings
.
TrimSpace
(
item
.
UpdatedAt
)
if
item
.
UpdatedAt
==
""
{
item
.
UpdatedAt
=
now
}
normalized
.
Items
=
append
(
normalized
.
Items
,
item
)
}
if
len
(
normalized
.
Items
)
==
0
{
normalized
.
ActiveProfileID
=
""
return
normalized
}
if
findSoraS3ProfileIndex
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
>=
0
{
return
normalized
}
normalized
.
ActiveProfileID
=
normalized
.
Items
[
0
]
.
ProfileID
return
normalized
}
func
convertSoraS3ProfilesStore
(
store
*
soraS3ProfilesStore
)
*
SoraS3ProfileList
{
if
store
==
nil
{
return
&
SoraS3ProfileList
{}
}
items
:=
make
([]
SoraS3Profile
,
0
,
len
(
store
.
Items
))
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
items
=
append
(
items
,
SoraS3Profile
{
ProfileID
:
item
.
ProfileID
,
Name
:
item
.
Name
,
IsActive
:
item
.
ProfileID
==
store
.
ActiveProfileID
,
Enabled
:
item
.
Enabled
,
Endpoint
:
item
.
Endpoint
,
Region
:
item
.
Region
,
Bucket
:
item
.
Bucket
,
AccessKeyID
:
item
.
AccessKeyID
,
SecretAccessKey
:
item
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
item
.
SecretAccessKey
!=
""
,
Prefix
:
item
.
Prefix
,
ForcePathStyle
:
item
.
ForcePathStyle
,
CDNURL
:
item
.
CDNURL
,
DefaultStorageQuotaBytes
:
item
.
DefaultStorageQuotaBytes
,
UpdatedAt
:
item
.
UpdatedAt
,
})
}
return
&
SoraS3ProfileList
{
ActiveProfileID
:
store
.
ActiveProfileID
,
Items
:
items
,
}
}
func
pickActiveSoraS3Profile
(
items
[]
SoraS3Profile
,
activeProfileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileByID
(
items
[]
SoraS3Profile
,
profileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
&
items
[
idx
]
}
}
return
nil
}
func
pickActiveSoraS3ProfileFromStore
(
items
[]
soraS3ProfileStoreItem
,
activeProfileID
string
)
*
soraS3ProfileStoreItem
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileIndex
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
int
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
idx
}
}
return
-
1
}
func
hasSoraS3ProfileID
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
bool
{
return
findSoraS3ProfileIndex
(
items
,
profileID
)
>=
0
}
func
isEmptyLegacySoraS3Settings
(
settings
*
SoraS3Settings
)
bool
{
if
settings
==
nil
{
return
true
}
if
settings
.
Enabled
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Endpoint
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Region
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Bucket
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
!=
""
{
return
false
}
if
settings
.
SecretAccessKey
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Prefix
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
CDNURL
)
!=
""
{
return
false
}
return
settings
.
DefaultStorageQuotaBytes
==
0
}
func
maxInt64
(
value
int64
,
min
int64
)
int64
{
if
value
<
min
{
return
min
}
return
value
}
backend/internal/service/settings_view.go
View file @
bb664d9b
...
...
@@ -39,6 +39,7 @@ type SystemSettings struct {
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
DefaultConcurrency
int
DefaultBalance
float64
...
...
@@ -81,11 +82,52 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
LinuxDoOAuthEnabled
bool
Version
string
}
// SoraS3Settings Sora S3 存储配置
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"-"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type
SoraS3ProfileList
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type
StreamTimeoutSettings
struct
{
// Enabled 是否启用流超时处理
...
...
Prev
1
…
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment