Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6447be45
Unverified
Commit
6447be45
authored
Mar 16, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 16, 2026
Browse files
Merge pull request #1047 from DaydreamCoding/fix/codex-stream-isolation
fix(gateway): 防止 OpenAI Codex 跨用户串流 + WS 连接池条件式 MarkBroken
parents
474165d7
3741617e
Changes
5
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/openai_gateway_messages.go
View file @
6447be45
...
@@ -107,10 +107,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
...
@@ -107,10 +107,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return
nil
,
fmt
.
Errorf
(
"build upstream request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"build upstream request: %w"
,
err
)
}
}
// Override session_id with a deterministic UUID derived from the
sticky
// Override session_id with a deterministic UUID derived from the
isolated
// session key
(buildUpstreamRequest may have set it to the raw value)
.
// session key
, ensuring different API keys produce different upstream sessions
.
if
promptCacheKey
!=
""
{
if
promptCacheKey
!=
""
{
upstreamReq
.
Header
.
Set
(
"session_id"
,
generateSessionUUID
(
promptCacheKey
))
apiKeyID
:=
getAPIKeyIDFromContext
(
c
)
upstreamReq
.
Header
.
Set
(
"session_id"
,
generateSessionUUID
(
isolateOpenAISessionID
(
apiKeyID
,
promptCacheKey
)))
}
}
// 7. Send request
// 7. Send request
...
...
backend/internal/service/openai_gateway_service.go
View file @
6447be45
...
@@ -24,6 +24,7 @@ import (
...
@@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
...
@@ -787,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 {
...
@@ -787,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 {
return
apiKey
.
ID
return
apiKey
.
ID
}
}
// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符,
// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id,
// 到达上游的标识符也不同,防止跨用户会话碰撞。
func
isolateOpenAISessionID
(
apiKeyID
int64
,
raw
string
)
string
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
""
}
h
:=
xxhash
.
New
()
_
,
_
=
fmt
.
Fprintf
(
h
,
"k%d:"
,
apiKeyID
)
_
,
_
=
h
.
WriteString
(
raw
)
return
fmt
.
Sprintf
(
"%016x"
,
h
.
Sum64
())
}
func
logCodexCLIOnlyDetection
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
apiKeyID
int64
,
result
CodexClientRestrictionDetectionResult
,
body
[]
byte
)
{
func
logCodexCLIOnlyDetection
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
apiKeyID
int64
,
result
CodexClientRestrictionDetectionResult
,
body
[]
byte
)
{
if
!
result
.
Enabled
{
if
!
result
.
Enabled
{
return
return
...
@@ -2501,13 +2516,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
...
@@ -2501,13 +2516,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if
chatgptAccountID
:=
account
.
GetChatGPTAccountID
();
chatgptAccountID
!=
""
{
if
chatgptAccountID
:=
account
.
GetChatGPTAccountID
();
chatgptAccountID
!=
""
{
req
.
Header
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
req
.
Header
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
}
}
apiKeyID
:=
getAPIKeyIDFromContext
(
c
)
// 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。
clientSessionID
:=
strings
.
TrimSpace
(
req
.
Header
.
Get
(
"session_id"
))
clientConversationID
:=
strings
.
TrimSpace
(
req
.
Header
.
Get
(
"conversation_id"
))
if
isOpenAIResponsesCompactPath
(
c
)
{
if
isOpenAIResponsesCompactPath
(
c
)
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
req
.
Header
.
Get
(
"version"
)
==
""
{
if
req
.
Header
.
Get
(
"version"
)
==
""
{
req
.
Header
.
Set
(
"version"
,
codexCLIVersion
)
req
.
Header
.
Set
(
"version"
,
codexCLIVersion
)
}
}
if
req
.
Header
.
Get
(
"s
ession
_id"
)
==
""
{
if
clientS
ession
ID
==
""
{
req
.
Header
.
Set
(
"session_id"
,
resolveOpenAICompactSessionID
(
c
)
)
clientSessionID
=
resolveOpenAICompactSessionID
(
c
)
}
}
}
else
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
}
else
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
...
@@ -2518,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
...
@@ -2518,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if
req
.
Header
.
Get
(
"originator"
)
==
""
{
if
req
.
Header
.
Get
(
"originator"
)
==
""
{
req
.
Header
.
Set
(
"originator"
,
"codex_cli_rs"
)
req
.
Header
.
Set
(
"originator"
,
"codex_cli_rs"
)
}
}
if
promptCacheKey
!=
""
{
// 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。
if
req
.
Header
.
Get
(
"conversation_id"
)
==
""
{
if
clientSessionID
==
""
{
req
.
Header
.
Set
(
"conversation_id"
,
promptCacheKey
)
clientSessionID
=
promptCacheKey
}
}
if
req
.
Header
.
Get
(
"session_id"
)
==
""
{
if
clientConversationID
==
""
{
req
.
Header
.
Set
(
"session_id"
,
promptCacheKey
)
clientConversationID
=
promptCacheKey
}
}
if
clientSessionID
!=
""
{
req
.
Header
.
Set
(
"session_id"
,
isolateOpenAISessionID
(
apiKeyID
,
clientSessionID
))
}
if
clientConversationID
!=
""
{
req
.
Header
.
Set
(
"conversation_id"
,
isolateOpenAISessionID
(
apiKeyID
,
clientConversationID
))
}
}
}
}
...
@@ -2887,22 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
...
@@ -2887,22 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
}
if
account
.
Type
==
AccountTypeOAuth
{
if
account
.
Type
==
AccountTypeOAuth
{
// 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。
req
.
Header
.
Del
(
"conversation_id"
)
req
.
Header
.
Del
(
"session_id"
)
req
.
Header
.
Set
(
"OpenAI-Beta"
,
"responses=experimental"
)
req
.
Header
.
Set
(
"OpenAI-Beta"
,
"responses=experimental"
)
req
.
Header
.
Set
(
"originator"
,
resolveOpenAIUpstreamOriginator
(
c
,
isCodexCLI
))
req
.
Header
.
Set
(
"originator"
,
resolveOpenAIUpstreamOriginator
(
c
,
isCodexCLI
))
apiKeyID
:=
getAPIKeyIDFromContext
(
c
)
if
isOpenAIResponsesCompactPath
(
c
)
{
if
isOpenAIResponsesCompactPath
(
c
)
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
req
.
Header
.
Get
(
"version"
)
==
""
{
if
req
.
Header
.
Get
(
"version"
)
==
""
{
req
.
Header
.
Set
(
"version"
,
codexCLIVersion
)
req
.
Header
.
Set
(
"version"
,
codexCLIVersion
)
}
}
if
req
.
Header
.
Get
(
"session_id"
)
==
""
{
compactSession
:=
resolveOpenAICompactSessionID
(
c
)
req
.
Header
.
Set
(
"session_id"
,
resolveOpenAICompactSessionID
(
c
))
req
.
Header
.
Set
(
"session_id"
,
isolateOpenAISessionID
(
apiKeyID
,
compactSession
))
}
}
else
{
}
else
{
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
req
.
Header
.
Set
(
"accept"
,
"text/event-stream"
)
}
}
if
promptCacheKey
!=
""
{
if
promptCacheKey
!=
""
{
req
.
Header
.
Set
(
"conversation_id"
,
promptCacheKey
)
isolated
:=
isolateOpenAISessionID
(
apiKeyID
,
promptCacheKey
)
req
.
Header
.
Set
(
"session_id"
,
promptCacheKey
)
req
.
Header
.
Set
(
"conversation_id"
,
isolated
)
req
.
Header
.
Set
(
"session_id"
,
isolated
)
}
}
}
}
...
...
backend/internal/service/openai_gateway_service_session_isolation_test.go
0 → 100644
View file @
6447be45
package
service
import
(
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestIsolateOpenAISessionID
(
t
*
testing
.
T
)
{
t
.
Run
(
"empty_raw_returns_empty"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
""
,
isolateOpenAISessionID
(
1
,
""
))
assert
.
Equal
(
t
,
""
,
isolateOpenAISessionID
(
1
,
" "
))
})
t
.
Run
(
"deterministic"
,
func
(
t
*
testing
.
T
)
{
a
:=
isolateOpenAISessionID
(
42
,
"sess_abc123"
)
b
:=
isolateOpenAISessionID
(
42
,
"sess_abc123"
)
assert
.
Equal
(
t
,
a
,
b
)
})
t
.
Run
(
"different_apiKeyID_different_result"
,
func
(
t
*
testing
.
T
)
{
a
:=
isolateOpenAISessionID
(
1
,
"same_session"
)
b
:=
isolateOpenAISessionID
(
2
,
"same_session"
)
require
.
NotEqual
(
t
,
a
,
b
,
"不同 API Key 使用相同 session_id 应产生不同隔离值"
)
})
t
.
Run
(
"different_raw_different_result"
,
func
(
t
*
testing
.
T
)
{
a
:=
isolateOpenAISessionID
(
1
,
"session_a"
)
b
:=
isolateOpenAISessionID
(
1
,
"session_b"
)
require
.
NotEqual
(
t
,
a
,
b
)
})
t
.
Run
(
"format_is_16_hex_chars"
,
func
(
t
*
testing
.
T
)
{
result
:=
isolateOpenAISessionID
(
99
,
"test_session"
)
assert
.
Len
(
t
,
result
,
16
,
"应为 16 字符的 hex 字符串"
)
for
_
,
ch
:=
range
result
{
assert
.
True
(
t
,
(
ch
>=
'0'
&&
ch
<=
'9'
)
||
(
ch
>=
'a'
&&
ch
<=
'f'
),
"应仅包含 hex 字符: %c"
,
ch
)
}
})
t
.
Run
(
"zero_apiKeyID_still_works"
,
func
(
t
*
testing
.
T
)
{
result
:=
isolateOpenAISessionID
(
0
,
"session"
)
assert
.
NotEmpty
(
t
,
result
)
// apiKeyID=0 与 apiKeyID=1 应产生不同结果
other
:=
isolateOpenAISessionID
(
1
,
"session"
)
assert
.
NotEqual
(
t
,
result
,
other
)
})
}
backend/internal/service/openai_ws_forwarder.go
View file @
6447be45
...
@@ -1124,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
...
@@ -1124,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
headers
.
Set
(
"accept-language"
,
v
)
headers
.
Set
(
"accept-language"
,
v
)
}
}
}
}
if
sessionResolution
.
SessionID
!=
""
{
// OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。
headers
.
Set
(
"session_id"
,
sessionResolution
.
SessionID
)
if
account
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
{
}
apiKeyID
:=
getAPIKeyIDFromContext
(
c
)
if
sessionResolution
.
ConversationID
!=
""
{
if
sessionResolution
.
SessionID
!=
""
{
headers
.
Set
(
"conversation_id"
,
sessionResolution
.
ConversationID
)
headers
.
Set
(
"session_id"
,
isolateOpenAISessionID
(
apiKeyID
,
sessionResolution
.
SessionID
))
}
if
sessionResolution
.
ConversationID
!=
""
{
headers
.
Set
(
"conversation_id"
,
isolateOpenAISessionID
(
apiKeyID
,
sessionResolution
.
ConversationID
))
}
}
else
{
if
sessionResolution
.
SessionID
!=
""
{
headers
.
Set
(
"session_id"
,
sessionResolution
.
SessionID
)
}
if
sessionResolution
.
ConversationID
!=
""
{
headers
.
Set
(
"conversation_id"
,
sessionResolution
.
ConversationID
)
}
}
}
if
state
:=
strings
.
TrimSpace
(
turnState
);
state
!=
""
{
if
state
:=
strings
.
TrimSpace
(
turnState
);
state
!=
""
{
headers
.
Set
(
openAIWSTurnStateHeader
,
state
)
headers
.
Set
(
openAIWSTurnStateHeader
,
state
)
...
@@ -1859,7 +1870,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
...
@@ -1859,7 +1870,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
}
return
nil
,
wrapOpenAIWSFallback
(
classifyOpenAIWSAcquireError
(
err
),
err
)
return
nil
,
wrapOpenAIWSFallback
(
classifyOpenAIWSAcquireError
(
err
),
err
)
}
}
defer
lease
.
Release
()
// cleanExit 标记正常终端事件退出,此时上游不会再发送帧,连接可安全归还复用。
// 所有异常路径(读写错误、error 事件等)已在各自分支中提前调用 MarkBroken,
// 因此 defer 中只需处理正常退出时不 MarkBroken 即可。
cleanExit
:=
false
defer
func
()
{
if
!
cleanExit
{
lease
.
MarkBroken
()
}
lease
.
Release
()
}()
connID
:=
strings
.
TrimSpace
(
lease
.
ConnID
())
connID
:=
strings
.
TrimSpace
(
lease
.
ConnID
())
logOpenAIWSModeDebug
(
logOpenAIWSModeDebug
(
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v"
,
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v"
,
...
@@ -2237,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
...
@@ -2237,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
}
if
isTerminalEvent
{
if
isTerminalEvent
{
cleanExit
=
true
break
break
}
}
}
}
...
@@ -2972,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -2972,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
pinnedSessionConnID
=
connID
pinnedSessionConnID
=
connID
}
}
}
}
// lastTurnClean 标记最后一轮 sendAndRelay 是否正常完成(收到终端事件且客户端未断连)。
// 所有异常路径(读写错误、error 事件、客户端断连)已在各自分支或上层(L3403)中 MarkBroken,
// 因此 releaseSessionLease 中只需在非正常结束时 MarkBroken。
lastTurnClean
:=
false
releaseSessionLease
:=
func
()
{
releaseSessionLease
:=
func
()
{
if
sessionLease
==
nil
{
if
sessionLease
==
nil
{
return
return
}
}
if
dedicatedMode
{
if
!
lastTurnClean
{
// dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
sessionLease
.
MarkBroken
()
sessionLease
.
MarkBroken
()
}
}
unpinSessionConn
(
sessionConnID
)
unpinSessionConn
(
sessionConnID
)
...
@@ -3372,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -3372,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
result
,
relayErr
:=
sendAndRelay
(
turn
,
sessionLease
,
currentPayload
,
currentPayloadBytes
,
currentOriginalModel
)
result
,
relayErr
:=
sendAndRelay
(
turn
,
sessionLease
,
currentPayload
,
currentPayloadBytes
,
currentOriginalModel
)
if
relayErr
!=
nil
{
if
relayErr
!=
nil
{
lastTurnClean
=
false
if
recoverIngressPrevResponseNotFound
(
relayErr
,
turn
,
connID
)
{
if
recoverIngressPrevResponseNotFound
(
relayErr
,
turn
,
connID
)
{
continue
continue
}
}
...
@@ -3391,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -3391,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnRetry
=
0
turnRetry
=
0
turnPrevRecoveryTried
=
false
turnPrevRecoveryTried
=
false
lastTurnFinishedAt
=
time
.
Now
()
lastTurnFinishedAt
=
time
.
Now
()
lastTurnClean
=
true
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
if
hooks
!=
nil
&&
hooks
.
AfterTurn
!=
nil
{
hooks
.
AfterTurn
(
turn
,
result
,
nil
)
hooks
.
AfterTurn
(
turn
,
result
,
nil
)
}
}
...
...
backend/internal/service/openai_ws_forwarder_success_test.go
View file @
6447be45
...
@@ -380,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
...
@@ -380,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
require
.
True
(
t
,
strings
.
HasPrefix
(
result
.
RequestID
,
"resp_reuse_"
))
require
.
True
(
t
,
strings
.
HasPrefix
(
result
.
RequestID
,
"resp_reuse_"
))
}
}
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
(),
"多个客户端请求应复用账号连接池而不是 1:1 对等建链"
)
// 条件式 MarkBroken:正常终端事件退出后连接归还复用,不再无条件销毁。
require
.
Equal
(
t
,
int64
(
1
),
upgradeCount
.
Load
(),
"正常完成后连接应归还复用,不应每次新建"
)
metrics
:=
svc
.
SnapshotOpenAIWSPoolMetrics
()
metrics
:=
svc
.
SnapshotOpenAIWSPoolMetrics
()
require
.
GreaterOrEqual
(
t
,
metrics
.
AcquireReuseTotal
,
int64
(
1
))
require
.
GreaterOrEqual
(
t
,
metrics
.
AcquireReuseTotal
,
int64
(
1
))
require
.
GreaterOrEqual
(
t
,
metrics
.
ConnPickTotal
,
int64
(
1
))
require
.
GreaterOrEqual
(
t
,
metrics
.
ConnPickTotal
,
int64
(
1
))
...
@@ -454,8 +455,10 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
...
@@ -454,8 +455,10 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Exists
(),
"WSv2 payload 应保留 stream 字段"
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Exists
(),
"WSv2 payload 应保留 stream 字段"
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Bool
(),
"OAuth Codex 规范化后应强制 stream=true"
)
require
.
True
(
t
,
gjson
.
Get
(
requestJSON
,
"stream"
)
.
Bool
(),
"OAuth Codex 规范化后应强制 stream=true"
)
require
.
Equal
(
t
,
openAIWSBetaV2Value
,
captureDialer
.
lastHeaders
.
Get
(
"OpenAI-Beta"
))
require
.
Equal
(
t
,
openAIWSBetaV2Value
,
captureDialer
.
lastHeaders
.
Get
(
"OpenAI-Beta"
))
require
.
Equal
(
t
,
"sess-oauth-1"
,
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
// OAuth 账号的 session_id/conversation_id 应被 isolateOpenAISessionID 隔离,
require
.
Equal
(
t
,
"conv-oauth-1"
,
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
// 测试中未设置 api_key 到 context,apiKeyID=0。
require
.
Equal
(
t
,
isolateOpenAISessionID
(
0
,
"sess-oauth-1"
),
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
require
.
Equal
(
t
,
isolateOpenAISessionID
(
0
,
"conv-oauth-1"
),
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
}
}
func
TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility
(
t
*
testing
.
T
)
{
...
@@ -596,7 +599,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
...
@@ -596,7 +599,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"resp_prompt_cache_key"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"resp_prompt_cache_key"
,
result
.
RequestID
)
require
.
Equal
(
t
,
"pcache_123"
,
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
// OAuth 账号的 session_id 应被 isolateOpenAISessionID 隔离(apiKeyID=0,未在 context 设置)。
require
.
Equal
(
t
,
isolateOpenAISessionID
(
0
,
"pcache_123"
),
captureDialer
.
lastHeaders
.
Get
(
"session_id"
))
require
.
Empty
(
t
,
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
require
.
Empty
(
t
,
captureDialer
.
lastHeaders
.
Get
(
"conversation_id"
))
require
.
NotNil
(
t
,
captureConn
.
lastWrite
)
require
.
NotNil
(
t
,
captureConn
.
lastWrite
)
require
.
True
(
t
,
gjson
.
Get
(
requestToJSONString
(
captureConn
.
lastWrite
),
"stream"
)
.
Exists
())
require
.
True
(
t
,
gjson
.
Get
(
requestToJSONString
(
captureConn
.
lastWrite
),
"stream"
)
.
Exists
())
...
@@ -961,6 +965,10 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
...
@@ -961,6 +965,10 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
require
.
NotNil
(
t
,
result1
)
require
.
NotNil
(
t
,
result1
)
require
.
Equal
(
t
,
"resp_meta_1"
,
result1
.
RequestID
)
require
.
Equal
(
t
,
"resp_meta_1"
,
result1
.
RequestID
)
require
.
Len
(
t
,
captureConn
.
writes
,
1
)
firstWrite
:=
requestToJSONString
(
captureConn
.
writes
[
0
])
require
.
Equal
(
t
,
"turn_meta_payload_1"
,
gjson
.
Get
(
firstWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
rec2
:=
httptest
.
NewRecorder
()
rec2
:=
httptest
.
NewRecorder
()
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
,
_
:=
gin
.
CreateTestContext
(
rec2
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
c2
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
...
@@ -974,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
...
@@ -974,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
require
.
Equal
(
t
,
1
,
captureDialer
.
DialCount
(),
"同一账号两轮请求应复用同一 WS 连接"
)
require
.
Equal
(
t
,
1
,
captureDialer
.
DialCount
(),
"同一账号两轮请求应复用同一 WS 连接"
)
require
.
Len
(
t
,
captureConn
.
writes
,
2
)
require
.
Len
(
t
,
captureConn
.
writes
,
2
)
firstWrite
:
=
requestToJSONString
(
captureConn
.
writes
[
0
])
firstWrite
=
requestToJSONString
(
captureConn
.
writes
[
0
])
secondWrite
:=
requestToJSONString
(
captureConn
.
writes
[
1
])
secondWrite
:=
requestToJSONString
(
captureConn
.
writes
[
1
])
require
.
Equal
(
t
,
"turn_meta_payload_1"
,
gjson
.
Get
(
firstWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
require
.
Equal
(
t
,
"turn_meta_payload_1"
,
gjson
.
Get
(
firstWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
require
.
Equal
(
t
,
"turn_meta_payload_2"
,
gjson
.
Get
(
secondWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
require
.
Equal
(
t
,
"turn_meta_payload_2"
,
gjson
.
Get
(
secondWrite
,
"client_metadata.x-codex-turn-metadata"
)
.
String
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment