"frontend/src/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "bda7c39e55fef97b489d959b684321513fcbf13d"
Unverified Commit fa68cbad authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 995ef134 0f033930
...@@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) { ...@@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
require.NotEmpty(t, block1["text"]) require.NotEmpty(t, block1["text"])
} }
func TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) {
// Empty text blocks nested inside tool_result content should also be stripped
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":"valid result"},
{"type":"text","text":""}
]}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
toolResult := content0[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"])
}
func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) {
// If all nested content blocks in tool_result are empty text, content becomes empty slice
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":""}
]},
{"type":"text","text":"hello"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 2)
toolResult := content0[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 0)
}
func TestStripEmptyTextBlocks(t *testing.T) {
t.Run("strips top-level empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
require.Equal(t, "hello", content[0].(map[string]any)["text"])
})
t.Run("strips nested empty text in tool_result", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
toolResult := content[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"])
})
t.Run("no-op when no empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := StripEmptyTextBlocks(input)
require.Equal(t, input, out)
})
t.Run("preserves non-map blocks in content", func(t *testing.T) {
// tool_result content can be a string; non-map blocks should pass through unchanged
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
toolResult := content[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
require.Equal(t, "string content", toolResult["content"])
})
t.Run("handles deeply nested tool_result", func(t *testing.T) {
// Recursive: tool_result containing another tool_result with empty text
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
outer := content[0].(map[string]any)
innerContent := outer["content"].([]any)
inner := innerContent[0].(map[string]any)
deepContent := inner["content"].([]any)
require.Len(t, deepContent, 1)
require.Equal(t, "deep", deepContent[0].(map[string]any)["text"])
})
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) { func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged // Non-empty text blocks should pass through unchanged
input := []byte(`{ input := []byte(`{
......
...@@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed.SessionContext != nil { if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP) _, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":") _, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent) _, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
_, _ = combined.WriteString(":") _, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|") _, _ = combined.WriteString("|")
...@@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 调试日志:记录即将转发的账号信息 // 调试日志:记录即将转发的账号信息
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, body)
...@@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
...@@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "signature_error", Kind: "signature_error",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
...@@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode, UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"), UpstreamRequestID: retryResp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(retryReq.URL.String()),
Kind: "signature_retry_thinking", Kind: "signature_retry_thinking",
Message: extractUpstreamErrorMessage(retryRespBody), Message: extractUpstreamErrorMessage(retryRespBody),
Detail: func() string { Detail: func() string {
...@@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(retryReq2.URL.String()),
Kind: "signature_retry_tools_request_error", Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
}) })
...@@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "budget_constraint_error", Kind: "budget_constraint_error",
Message: errMsg, Message: errMsg,
Detail: func() string { Detail: func() string {
...@@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
...@@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( ...@@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
if c != nil { if c != nil {
c.Set("anthropic_passthrough", true) c.Set("anthropic_passthrough", true)
} }
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
input.Body = StripEmptyTextBlocks(input.Body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, input.Body) setOpsUpstreamRequestBody(c, input.Body)
...@@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( ...@@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
...@@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( ...@@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
...@@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream( ...@@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
...@@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream( ...@@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
...@@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
...@@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex ...@@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "request_error", Kind: "request_error",
Message: sanitizeUpstreamErrorMessage(err.Error()), Message: sanitizeUpstreamErrorMessage(err.Error()),
...@@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex ...@@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "http_error", Kind: "http_error",
Message: upstreamMsg, Message: upstreamMsg,
......
...@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error ...@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string { ...@@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
// 返回 16 字符的 Base64 编码的 SHA256 前缀 // 返回 16 字符的 Base64 编码的 SHA256 前缀
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
// 组合所有标识符 // 组合所有标识符
normalizedUserAgent := NormalizeSessionUserAgent(userAgent)
combined := strconv.FormatInt(userID, 10) + ":" + combined := strconv.FormatInt(userID, 10) + ":" +
strconv.FormatInt(apiKeyID, 10) + ":" + strconv.FormatInt(apiKeyID, 10) + ":" +
ip + ":" + ip + ":" +
userAgent + ":" + normalizedUserAgent + ":" +
platform + ":" + platform + ":" +
model model
......
...@@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { ...@@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
} }
} }
func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestParseGeminiSessionValue(t *testing.T) { func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
......
...@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if tierID != "" { if tierID != "" {
account.Credentials["tier_id"] = tierID account.Credentials["tier_id"] = tierID
} }
_ = p.accountRepo.Update(ctx, account) _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
} }
} }
......
...@@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { ...@@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
} }
func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0"))
h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1"))
require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0"))
h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1"))
require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
......
...@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ...@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB // 5. 设置版本号 + 更新 DB
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
freshAccount.Credentials = newCredentials if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
slog.Error("oauth_refresh_update_failed", slog.Error("oauth_refresh_update_failed",
"account_id", freshAccount.ID, "account_id", freshAccount.ID,
"error", updateErr, "error", updateErr,
......
...@@ -16,10 +16,11 @@ import ( ...@@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct { type refreshAPIAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
account *Account // returned by GetByID account *Account // returned by GetByID
getByIDErr error getByIDErr error
updateErr error updateErr error
updateCalls int updateCalls int
updateCredentialsCalls int
} }
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
...@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { ...@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return r.updateErr return r.updateErr
} }
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
if r.account == nil || r.account.ID != id {
r.account = &Account{ID: id}
}
r.account.Credentials = cloneCredentials(credentials)
return nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct { type refreshAPIExecutorStub struct {
needsRefresh bool needsRefresh bool
...@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) { ...@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls) require.Equal(t, 1, executor.refreshCalls)
} }
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
resetAt := time.Now().Add(45 * time.Minute)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "safe-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotNil(t, repo.account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) { func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic} account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account} repo := &refreshAPIAccountRepo{account: account}
...@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) { ...@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant") require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
} }
...@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) { ...@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result := MergeCredentials(old, new) result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
} }
// ========== BuildClaudeAccountCredentials tests ========== // ========== BuildClaudeAccountCredentials tests ==========
......
...@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( ...@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil return nil, nil
} }
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {
...@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ...@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue continue
} }
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil { if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr return nil, len(candidates), topK, loadSkew, acquireErr
......
...@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa ...@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID) require.Equal(t, int64(32002), account.ID)
} }
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background()
groupID := int64(10103)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(33002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10104)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{stalePrimary, staleSecondary},
accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(34002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(9) groupID := int64(9)
......
...@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil return nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
...@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ ...@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
if fresh == nil {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
...@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) { (requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if err == nil && result.Acquired { if account == nil {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return &AccountSelectionResult{ } else {
Account: account, result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
Acquired: true, if err == nil && result.Acquired {
ReleaseFunc: result.ReleaseFunc, _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
}, nil return &AccountSelectionResult{
} Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
AccountID: accountID, AccountID: accountID,
MaxConcurrency: account.Concurrency, MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout, Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting, MaxWaiting: cfg.StickySessionMaxWaiting,
}, },
}, nil }, nil
}
} }
} }
} }
...@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. ...@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh return fresh
} }
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
return account
}
latest, err := s.accountRepo.GetByID(ctx, account.ID)
if err != nil || latest == nil {
return nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
return nil
}
return latest
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
var ( var (
account *Account account *Account
...@@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( ...@@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
} }
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
......
...@@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF ...@@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require.True(t, arr[len(arr)-1].Passthrough) require.True(t, arr[len(arr)-1].Passthrough)
} }
func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-rate-limit"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))),
}
upstream := &httpUpstreamRecorder{resp: resp}
repo := &openAIWSRateLimitSignalRepo{}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
rateLimitService: rateSvc,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Contains(t, rec.Body.String(), "usage_limit_reached")
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
}
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -29,9 +29,10 @@ type soraSessionChunk struct { ...@@ -29,9 +29,10 @@ type soraSessionChunk struct {
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct { type OpenAIOAuthService struct {
sessionStore *openai.SessionStore sessionStore *openai.SessionStore
proxyRepo ProxyRepository proxyRepo ProxyRepository
oauthClient OpenAIOAuthClient oauthClient OpenAIOAuthClient
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-api(ImpersonateChrome)
} }
// NewOpenAIOAuthService creates a new OpenAI OAuth service // NewOpenAIOAuthService creates a new OpenAI OAuth service
...@@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli ...@@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli
} }
} }
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
s.privacyClientFactory = factory
}
// OpenAIAuthURLResult contains the authorization URL and session info // OpenAIAuthURLResult contains the authorization URL and session info
type OpenAIAuthURLResult struct { type OpenAIAuthURLResult struct {
AuthURL string `json:"auth_url"` AuthURL string `json:"auth_url"`
...@@ -131,6 +138,7 @@ type OpenAITokenInfo struct { ...@@ -131,6 +138,7 @@ type OpenAITokenInfo struct {
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"` PlanType string `json:"plan_type,omitempty"`
PrivacyMode string `json:"privacy_mode,omitempty"`
} }
// ExchangeCode exchanges authorization code for tokens // ExchangeCode exchanges authorization code for tokens
...@@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre ...@@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.PlanType = userInfo.PlanType tokenInfo.PlanType = userInfo.PlanType
} }
// id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全
if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号
orgID := tokenInfo.OrganizationID
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
orgID = atClaims.OpenAIAuth.POID
}
}
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
if tokenInfo.PlanType == "" && info.PlanType != "" {
tokenInfo.PlanType = info.PlanType
}
if tokenInfo.Email == "" && info.Email != "" {
tokenInfo.Email = info.Email
}
}
}
// 尝试设置隐私(关闭训练数据共享),best-effort
if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
return tokenInfo, nil return tokenInfo, nil
} }
......
...@@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto ...@@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
return PrivacyModeTrainingOff return PrivacyModeTrainingOff
} }
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
type ChatGPTAccountInfo struct {
PlanType string
Email string
}
const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
// Returns nil on any failure (best-effort, non-blocking).
func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo {
if accessToken == "" || clientFactory == nil {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
client, err := clientFactory(proxyURL)
if err != nil {
slog.Debug("chatgpt_account_check_client_error", "error", err.Error())
return nil
}
var result map[string]any
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Origin", "https://chatgpt.com").
SetHeader("Referer", "https://chatgpt.com/").
SetHeader("Accept", "application/json").
SetSuccessResult(&result).
Get(chatGPTAccountsCheckURL)
if err != nil {
slog.Debug("chatgpt_account_check_request_error", "error", err.Error())
return nil
}
if !resp.IsSuccessState() {
slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
return nil
}
info := &ChatGPTAccountInfo{}
accounts, ok := result["accounts"].(map[string]any)
if !ok {
slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300))
return nil
}
// 优先匹配 orgID 对应的账号(access_token JWT 中的 poid)
if orgID != "" {
if matched := extractPlanFromAccount(accounts, orgID); matched != "" {
info.PlanType = matched
}
}
// 未匹配到时,遍历所有账号:优先 is_default,次选非 free
if info.PlanType == "" {
var defaultPlan, paidPlan, anyPlan string
for _, acctRaw := range accounts {
acct, ok := acctRaw.(map[string]any)
if !ok {
continue
}
planType := extractPlanType(acct)
if planType == "" {
continue
}
if anyPlan == "" {
anyPlan = planType
}
if account, ok := acct["account"].(map[string]any); ok {
if isDefault, _ := account["is_default"].(bool); isDefault {
defaultPlan = planType
}
}
if !strings.EqualFold(planType, "free") && paidPlan == "" {
paidPlan = planType
}
}
// 优先级:default > 非 free > 任意
switch {
case defaultPlan != "":
info.PlanType = defaultPlan
case paidPlan != "":
info.PlanType = paidPlan
default:
info.PlanType = anyPlan
}
}
if info.PlanType == "" {
slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300))
return nil
}
slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID)
return info
}
// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type
func extractPlanFromAccount(accounts map[string]any, accountKey string) string {
acctRaw, ok := accounts[accountKey]
if !ok {
return ""
}
acct, ok := acctRaw.(map[string]any)
if !ok {
return ""
}
return extractPlanType(acct)
}
// extractPlanType 从单个 account 对象中提取 plan_type
func extractPlanType(acct map[string]any) string {
if account, ok := acct["account"].(map[string]any); ok {
if planType, ok := account["plan_type"].(string); ok && planType != "" {
return planType
}
}
if entitlement, ok := acct["entitlement"].(map[string]any); ok {
if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" {
return subPlan
}
}
return ""
}
func truncate(s string, n int) string { func truncate(s string, n int) string {
if len(s) <= n { if len(s) <= n {
return s return s
......
...@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( ...@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.Zero(t, boundAccountID) require.Zero(t, boundAccountID)
} }
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(24)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleAccount := &Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
dbAccount := Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
RateLimitResetAt: &rateLimitedUntil,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
snapshotCache := &openAISnapshotCacheStub{
accountsByID: map[int64]*Account{dbAccount.ID: staleAccount},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
require.NoError(t, getErr)
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(23) groupID := int64(23)
......
...@@ -3840,6 +3840,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( ...@@ -3840,6 +3840,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil return nil, nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {
......
...@@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re ...@@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
return nil return nil
} }
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
_ = platform _ = platform
_ = accountType _ = accountType
_ = status _ = status
_ = search _ = search
_ = groupID _ = groupID
_ = privacyMode
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
} }
...@@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount( ...@@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Len(t, accounts, 1) require.Len(t, accounts, 1)
......
...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s ...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page, Page: page,
PageSize: opsAccountsPageSize, PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "", 0) }, platformFilter, "", "", "", 0, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -62,6 +62,12 @@ type OpsErrorLog struct { ...@@ -62,6 +62,12 @@ type OpsErrorLog struct {
ClientIP *string `json:"client_ip"` ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"` RequestPath string `json:"request_path"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
InboundEndpoint string `json:"inbound_endpoint"`
UpstreamEndpoint string `json:"upstream_endpoint"`
RequestedModel string `json:"requested_model"`
UpstreamModel string `json:"upstream_model"`
RequestType *int16 `json:"request_type"`
} }
type OpsErrorLogDetail struct { type OpsErrorLogDetail struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment