Commit ab4e8b2c authored by QTom's avatar QTom
Browse files

fix(gateway): 防止 OpenAI Codex 跨用户串流

根因:多个用户共享同一 OAuth 账号时,conversation_id/session_id 头
未做用户隔离,导致上游 chatgpt.com 将不同用户的请求关联到同一会话。

HTTP SSE 修复:
- 新增 isolateOpenAISessionID(apiKeyID, raw),将 API Key ID 混入
  session 标识符(xxhash),确保不同 Key 的用户产生不同上游会话
- buildUpstreamRequest: OAuth 分支先 Del 客户端透传的 session 头,
  再用隔离值覆盖
- buildUpstreamRequestOpenAIPassthrough: 透传路径同样隔离
- ForwardAsAnthropic: Anthropic Messages 兼容路径同步修复
- buildOpenAIWSHeaders: WS 路径的 OAuth session 头同步隔离
parent 474165d7
...@@ -107,10 +107,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -107,10 +107,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
// Override session_id with a deterministic UUID derived from the sticky // Override session_id with a deterministic UUID derived from the isolated
// session key (buildUpstreamRequest may have set it to the raw value). // session key, ensuring different API keys produce different upstream sessions.
if promptCacheKey != "" { if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) apiKeyID := getAPIKeyIDFromContext(c)
upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey)))
} }
// 7. Send request // 7. Send request
......
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
...@@ -787,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 { ...@@ -787,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 {
return apiKey.ID return apiKey.ID
} }
// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符,
// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id,
// 到达上游的标识符也不同,防止跨用户会话碰撞。
func isolateOpenAISessionID(apiKeyID int64, raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
h := xxhash.New()
_, _ = fmt.Fprintf(h, "k%d:", apiKeyID)
_, _ = h.WriteString(raw)
return fmt.Sprintf("%016x", h.Sum64())
}
func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) {
if !result.Enabled { if !result.Enabled {
return return
...@@ -2501,13 +2516,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( ...@@ -2501,13 +2516,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID) req.Header.Set("chatgpt-account-id", chatgptAccountID)
} }
apiKeyID := getAPIKeyIDFromContext(c)
// 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。
clientSessionID := strings.TrimSpace(req.Header.Get("session_id"))
clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id"))
if isOpenAIResponsesCompactPath(c) { if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json") req.Header.Set("accept", "application/json")
if req.Header.Get("version") == "" { if req.Header.Get("version") == "" {
req.Header.Set("version", codexCLIVersion) req.Header.Set("version", codexCLIVersion)
} }
if req.Header.Get("session_id") == "" { if clientSessionID == "" {
req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) clientSessionID = resolveOpenAICompactSessionID(c)
} }
} else if req.Header.Get("accept") == "" { } else if req.Header.Get("accept") == "" {
req.Header.Set("accept", "text/event-stream") req.Header.Set("accept", "text/event-stream")
...@@ -2518,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( ...@@ -2518,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if req.Header.Get("originator") == "" { if req.Header.Get("originator") == "" {
req.Header.Set("originator", "codex_cli_rs") req.Header.Set("originator", "codex_cli_rs")
} }
if promptCacheKey != "" { // 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。
if req.Header.Get("conversation_id") == "" { if clientSessionID == "" {
req.Header.Set("conversation_id", promptCacheKey) clientSessionID = promptCacheKey
} }
if req.Header.Get("session_id") == "" { if clientConversationID == "" {
req.Header.Set("session_id", promptCacheKey) clientConversationID = promptCacheKey
} }
if clientSessionID != "" {
req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID))
}
if clientConversationID != "" {
req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID))
} }
} }
...@@ -2887,22 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -2887,22 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
} }
} }
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
// 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。
req.Header.Del("conversation_id")
req.Header.Del("session_id")
req.Header.Set("OpenAI-Beta", "responses=experimental") req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
apiKeyID := getAPIKeyIDFromContext(c)
if isOpenAIResponsesCompactPath(c) { if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json") req.Header.Set("accept", "application/json")
if req.Header.Get("version") == "" { if req.Header.Get("version") == "" {
req.Header.Set("version", codexCLIVersion) req.Header.Set("version", codexCLIVersion)
} }
if req.Header.Get("session_id") == "" { compactSession := resolveOpenAICompactSessionID(c)
req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, compactSession))
}
} else { } else {
req.Header.Set("accept", "text/event-stream") req.Header.Set("accept", "text/event-stream")
} }
if promptCacheKey != "" { if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey) isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey)
req.Header.Set("session_id", promptCacheKey) req.Header.Set("conversation_id", isolated)
req.Header.Set("session_id", isolated)
} }
} }
......
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsolateOpenAISessionID(t *testing.T) {
t.Run("empty_raw_returns_empty", func(t *testing.T) {
assert.Equal(t, "", isolateOpenAISessionID(1, ""))
assert.Equal(t, "", isolateOpenAISessionID(1, " "))
})
t.Run("deterministic", func(t *testing.T) {
a := isolateOpenAISessionID(42, "sess_abc123")
b := isolateOpenAISessionID(42, "sess_abc123")
assert.Equal(t, a, b)
})
t.Run("different_apiKeyID_different_result", func(t *testing.T) {
a := isolateOpenAISessionID(1, "same_session")
b := isolateOpenAISessionID(2, "same_session")
require.NotEqual(t, a, b, "不同 API Key 使用相同 session_id 应产生不同隔离值")
})
t.Run("different_raw_different_result", func(t *testing.T) {
a := isolateOpenAISessionID(1, "session_a")
b := isolateOpenAISessionID(1, "session_b")
require.NotEqual(t, a, b)
})
t.Run("format_is_16_hex_chars", func(t *testing.T) {
result := isolateOpenAISessionID(99, "test_session")
assert.Len(t, result, 16, "应为 16 字符的 hex 字符串")
for _, ch := range result {
assert.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'),
"应仅包含 hex 字符: %c", ch)
}
})
t.Run("zero_apiKeyID_still_works", func(t *testing.T) {
result := isolateOpenAISessionID(0, "session")
assert.NotEmpty(t, result)
// apiKeyID=0 与 apiKeyID=1 应产生不同结果
other := isolateOpenAISessionID(1, "session")
assert.NotEqual(t, result, other)
})
}
...@@ -1124,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( ...@@ -1124,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
headers.Set("accept-language", v) headers.Set("accept-language", v)
} }
} }
if sessionResolution.SessionID != "" { // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。
headers.Set("session_id", sessionResolution.SessionID) if account != nil && account.Type == AccountTypeOAuth {
} apiKeyID := getAPIKeyIDFromContext(c)
if sessionResolution.ConversationID != "" { if sessionResolution.SessionID != "" {
headers.Set("conversation_id", sessionResolution.ConversationID) headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID))
}
if sessionResolution.ConversationID != "" {
headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID))
}
} else {
if sessionResolution.SessionID != "" {
headers.Set("session_id", sessionResolution.SessionID)
}
if sessionResolution.ConversationID != "" {
headers.Set("conversation_id", sessionResolution.ConversationID)
}
} }
if state := strings.TrimSpace(turnState); state != "" { if state := strings.TrimSpace(turnState); state != "" {
headers.Set(openAIWSTurnStateHeader, state) headers.Set(openAIWSTurnStateHeader, state)
......
...@@ -454,8 +454,10 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T ...@@ -454,8 +454,10 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) // OAuth 账号的 session_id/conversation_id 应被 isolateOpenAISessionID 隔离,
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) // 测试中未设置 api_key 到 context,apiKeyID=0。
require.Equal(t, isolateOpenAISessionID(0, "sess-oauth-1"), captureDialer.lastHeaders.Get("session_id"))
require.Equal(t, isolateOpenAISessionID(0, "conv-oauth-1"), captureDialer.lastHeaders.Get("conversation_id"))
} }
func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) { func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) {
...@@ -596,7 +598,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK ...@@ -596,7 +598,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, "resp_prompt_cache_key", result.RequestID) require.Equal(t, "resp_prompt_cache_key", result.RequestID)
require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) // OAuth 账号的 session_id 应被 isolateOpenAISessionID 隔离(apiKeyID=0,未在 context 设置)。
require.Equal(t, isolateOpenAISessionID(0, "pcache_123"), captureDialer.lastHeaders.Get("session_id"))
require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
require.NotNil(t, captureConn.lastWrite) require.NotNil(t, captureConn.lastWrite)
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment