Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
fa68cbad
Unverified
Commit
fa68cbad
authored
Mar 24, 2026
by
InCerryGit
Committed by
GitHub
Mar 24, 2026
Browse files
Merge branch 'Wei-Shaw:main' into main
parents
995ef134
0f033930
Changes
87
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/gateway_request_test.go
View file @
fa68cbad
...
...
@@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
require
.
NotEmpty
(
t
,
block1
[
"text"
])
}
func
TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult
(
t
*
testing
.
T
)
{
// Empty text blocks nested inside tool_result content should also be stripped
input
:=
[]
byte
(
`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":"valid result"},
{"type":"text","text":""}
]}
]}
]
}`
)
out
:=
FilterThinkingBlocksForRetry
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
msg0
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
content0
:=
msg0
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
content0
,
1
)
toolResult
:=
content0
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"tool_result"
,
toolResult
[
"type"
])
nestedContent
:=
toolResult
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
nestedContent
,
1
)
require
.
Equal
(
t
,
"valid result"
,
nestedContent
[
0
]
.
(
map
[
string
]
any
)[
"text"
])
}
func
TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice
(
t
*
testing
.
T
)
{
// If all nested content blocks in tool_result are empty text, content becomes empty slice
input
:=
[]
byte
(
`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":""}
]},
{"type":"text","text":"hello"}
]}
]
}`
)
out
:=
FilterThinkingBlocksForRetry
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
msg0
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
content0
:=
msg0
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
content0
,
2
)
toolResult
:=
content0
[
0
]
.
(
map
[
string
]
any
)
nestedContent
:=
toolResult
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
nestedContent
,
0
)
}
func
TestStripEmptyTextBlocks
(
t
*
testing
.
T
)
{
t
.
Run
(
"strips top-level empty text"
,
func
(
t
*
testing
.
T
)
{
input
:=
[]
byte
(
`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`
)
out
:=
StripEmptyTextBlocks
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
content
:=
msgs
[
0
]
.
(
map
[
string
]
any
)[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
content
,
1
)
require
.
Equal
(
t
,
"hello"
,
content
[
0
]
.
(
map
[
string
]
any
)[
"text"
])
})
t
.
Run
(
"strips nested empty text in tool_result"
,
func
(
t
*
testing
.
T
)
{
input
:=
[]
byte
(
`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`
)
out
:=
StripEmptyTextBlocks
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
content
:=
msgs
[
0
]
.
(
map
[
string
]
any
)[
"content"
]
.
([]
any
)
toolResult
:=
content
[
0
]
.
(
map
[
string
]
any
)
nestedContent
:=
toolResult
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
nestedContent
,
1
)
require
.
Equal
(
t
,
"ok"
,
nestedContent
[
0
]
.
(
map
[
string
]
any
)[
"text"
])
})
t
.
Run
(
"no-op when no empty text"
,
func
(
t
*
testing
.
T
)
{
input
:=
[]
byte
(
`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
)
out
:=
StripEmptyTextBlocks
(
input
)
require
.
Equal
(
t
,
input
,
out
)
})
t
.
Run
(
"preserves non-map blocks in content"
,
func
(
t
*
testing
.
T
)
{
// tool_result content can be a string; non-map blocks should pass through unchanged
input
:=
[]
byte
(
`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`
)
out
:=
StripEmptyTextBlocks
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
content
:=
msgs
[
0
]
.
(
map
[
string
]
any
)[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
content
,
1
)
toolResult
:=
content
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"tool_result"
,
toolResult
[
"type"
])
require
.
Equal
(
t
,
"string content"
,
toolResult
[
"content"
])
})
t
.
Run
(
"handles deeply nested tool_result"
,
func
(
t
*
testing
.
T
)
{
// Recursive: tool_result containing another tool_result with empty text
input
:=
[]
byte
(
`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`
)
out
:=
StripEmptyTextBlocks
(
input
)
var
req
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
out
,
&
req
))
msgs
:=
req
[
"messages"
]
.
([]
any
)
content
:=
msgs
[
0
]
.
(
map
[
string
]
any
)[
"content"
]
.
([]
any
)
outer
:=
content
[
0
]
.
(
map
[
string
]
any
)
innerContent
:=
outer
[
"content"
]
.
([]
any
)
inner
:=
innerContent
[
0
]
.
(
map
[
string
]
any
)
deepContent
:=
inner
[
"content"
]
.
([]
any
)
require
.
Len
(
t
,
deepContent
,
1
)
require
.
Equal
(
t
,
"deep"
,
deepContent
[
0
]
.
(
map
[
string
]
any
)[
"text"
])
})
}
func
TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks
(
t
*
testing
.
T
)
{
// Non-empty text blocks should pass through unchanged
input
:=
[]
byte
(
`{
...
...
backend/internal/service/gateway_service.go
View file @
fa68cbad
...
...
@@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if
parsed
.
SessionContext
!=
nil
{
_
,
_
=
combined
.
WriteString
(
parsed
.
SessionContext
.
ClientIP
)
_
,
_
=
combined
.
WriteString
(
":"
)
_
,
_
=
combined
.
WriteString
(
parsed
.
SessionContext
.
UserAgent
)
_
,
_
=
combined
.
WriteString
(
NormalizeSessionUserAgent
(
parsed
.
SessionContext
.
UserAgent
)
)
_
,
_
=
combined
.
WriteString
(
":"
)
_
,
_
=
combined
.
WriteString
(
strconv
.
FormatInt
(
parsed
.
SessionContext
.
APIKeyID
,
10
))
_
,
_
=
combined
.
WriteString
(
"|"
)
...
...
@@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 调试日志:记录即将转发的账号信息
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s"
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
account
.
Type
,
account
.
IsTLSFingerprintEnabled
(),
proxyURL
)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body
=
StripEmptyTextBlocks
(
body
)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody
(
c
,
body
)
...
...
@@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"request_error"
,
Message
:
safeErr
,
})
...
...
@@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"signature_error"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
...
...
@@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
retryResp
.
StatusCode
,
UpstreamRequestID
:
retryResp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
retryReq
.
URL
.
String
()),
Kind
:
"signature_retry_thinking"
,
Message
:
extractUpstreamErrorMessage
(
retryRespBody
),
Detail
:
func
()
string
{
...
...
@@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
UpstreamURL
:
safeUpstreamURL
(
retryReq2
.
URL
.
String
()),
Kind
:
"signature_retry_tools_request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
retryErr2
.
Error
()),
})
...
...
@@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"budget_constraint_error"
,
Message
:
errMsg
,
Detail
:
func
()
string
{
...
...
@@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"retry"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
...
...
@@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
if
c
!=
nil
{
c
.
Set
(
"anthropic_passthrough"
,
true
)
}
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
input
.
Body
=
StripEmptyTextBlocks
(
input
.
Body
)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody
(
c
,
input
.
Body
)
...
...
@@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Passthrough
:
true
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Passthrough
:
true
,
Kind
:
"retry"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
...
...
@@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"request_error"
,
Message
:
safeErr
,
})
...
...
@@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Kind
:
"retry"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
...
...
@@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
// Pre-filter: strip empty text blocks to prevent upstream 400.
body
=
StripEmptyTextBlocks
(
body
)
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
...
...
@@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Passthrough
:
true
,
Kind
:
"request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
err
.
Error
()),
...
...
@@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
UpstreamURL
:
safeUpstreamURL
(
upstreamReq
.
URL
.
String
()),
Passthrough
:
true
,
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
fa68cbad
...
...
@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func
(
m
*
mockAccountRepoForGemini
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
m
*
mockAccountRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
...
...
backend/internal/service/gemini_session.go
View file @
fa68cbad
...
...
@@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
// 返回 16 字符的 Base64 编码的 SHA256 前缀
func
GenerateGeminiPrefixHash
(
userID
,
apiKeyID
int64
,
ip
,
userAgent
,
platform
,
model
string
)
string
{
// 组合所有标识符
normalizedUserAgent
:=
NormalizeSessionUserAgent
(
userAgent
)
combined
:=
strconv
.
FormatInt
(
userID
,
10
)
+
":"
+
strconv
.
FormatInt
(
apiKeyID
,
10
)
+
":"
+
ip
+
":"
+
u
serAgent
+
":"
+
normalizedU
serAgent
+
":"
+
platform
+
":"
+
model
...
...
backend/internal/service/gemini_session_test.go
View file @
fa68cbad
...
...
@@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func
TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise
(
t
*
testing
.
T
)
{
hash1
:=
GenerateGeminiPrefixHash
(
1
,
100
,
"192.168.1.1"
,
"Mozilla/5.0 codex_cli_rs/0.1.0"
,
"antigravity"
,
"gemini-2.5-pro"
)
hash2
:=
GenerateGeminiPrefixHash
(
1
,
100
,
"192.168.1.1"
,
"Mozilla/5.0 codex_cli_rs/0.1.1"
,
"antigravity"
,
"gemini-2.5-pro"
)
if
hash1
!=
hash2
{
t
.
Fatalf
(
"version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s"
,
hash1
,
hash2
)
}
}
func
TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise
(
t
*
testing
.
T
)
{
hash1
:=
GenerateGeminiPrefixHash
(
1
,
100
,
"192.168.1.1"
,
"Codex CLI 0.1.0"
,
"antigravity"
,
"gemini-2.5-pro"
)
hash2
:=
GenerateGeminiPrefixHash
(
1
,
100
,
"192.168.1.1"
,
"Codex CLI 0.1.1"
,
"antigravity"
,
"gemini-2.5-pro"
)
if
hash1
!=
hash2
{
t
.
Fatalf
(
"free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s"
,
hash1
,
hash2
)
}
}
func
TestParseGeminiSessionValue
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
...
...
backend/internal/service/gemini_token_provider.go
View file @
fa68cbad
...
...
@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
tierID
!=
""
{
account
.
Credentials
[
"tier_id"
]
=
tierID
}
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
_
=
p
ersistAccountCredentials
(
ctx
,
p
.
accountRepo
,
account
,
account
.
Credentials
)
}
}
...
...
backend/internal/service/generate_session_hash_test.go
View file @
fa68cbad
...
...
@@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
require
.
NotEqual
(
t
,
h1
,
h2
,
"different User-Agent should produce different hash"
)
}
func
TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
base
:=
func
(
ua
string
)
*
ParsedRequest
{
return
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"test"
},
},
SessionContext
:
&
SessionContext
{
ClientIP
:
"1.1.1.1"
,
UserAgent
:
ua
,
APIKeyID
:
1
,
},
}
}
h1
:=
svc
.
GenerateSessionHash
(
base
(
"Mozilla/5.0 codex_cli_rs/0.1.0"
))
h2
:=
svc
.
GenerateSessionHash
(
base
(
"Mozilla/5.0 codex_cli_rs/0.1.1"
))
require
.
Equal
(
t
,
h1
,
h2
,
"version-only User-Agent changes should not perturb the sticky session hash"
)
}
func
TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
base
:=
func
(
ua
string
)
*
ParsedRequest
{
return
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"test"
},
},
SessionContext
:
&
SessionContext
{
ClientIP
:
"1.1.1.1"
,
UserAgent
:
ua
,
APIKeyID
:
1
,
},
}
}
h1
:=
svc
.
GenerateSessionHash
(
base
(
"Codex CLI 0.1.0"
))
h2
:=
svc
.
GenerateSessionHash
(
base
(
"Codex CLI 0.1.1"
))
require
.
Equal
(
t
,
h1
,
h2
,
"free-form version-only User-Agent changes should not perturb the sticky session hash"
)
}
func
TestGenerateSessionHash_SessionContext_APIKeyIDDifference
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
...
...
backend/internal/service/oauth_refresh_api.go
View file @
fa68cbad
...
...
@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
freshAccount
.
Credentials
=
newCredentials
if
updateErr
:=
api
.
accountRepo
.
Update
(
ctx
,
freshAccount
);
updateErr
!=
nil
{
if
updateErr
:=
persistAccountCredentials
(
ctx
,
api
.
accountRepo
,
freshAccount
,
newCredentials
);
updateErr
!=
nil
{
slog
.
Error
(
"oauth_refresh_update_failed"
,
"account_id"
,
freshAccount
.
ID
,
"error"
,
updateErr
,
...
...
backend/internal/service/oauth_refresh_api_test.go
View file @
fa68cbad
...
...
@@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type
refreshAPIAccountRepo
struct
{
mockAccountRepoForGemini
account
*
Account
// returned by GetByID
getByIDErr
error
updateErr
error
updateCalls
int
account
*
Account
// returned by GetByID
getByIDErr
error
updateErr
error
updateCalls
int
updateCredentialsCalls
int
}
func
(
r
*
refreshAPIAccountRepo
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
Account
,
error
)
{
...
...
@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return
r
.
updateErr
}
func
(
r
*
refreshAPIAccountRepo
)
UpdateCredentials
(
_
context
.
Context
,
id
int64
,
credentials
map
[
string
]
any
)
error
{
r
.
updateCalls
++
r
.
updateCredentialsCalls
++
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
if
r
.
account
==
nil
||
r
.
account
.
ID
!=
id
{
r
.
account
=
&
Account
{
ID
:
id
}
}
r
.
account
.
Credentials
=
cloneCredentials
(
credentials
)
return
nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type
refreshAPIExecutorStub
struct
{
needsRefresh
bool
...
...
@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require
.
Equal
(
t
,
"new-token"
,
result
.
NewCredentials
[
"access_token"
])
require
.
NotNil
(
t
,
result
.
NewCredentials
[
"_token_version"
])
// version stamp set
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState
(
t
*
testing
.
T
)
{
resetAt
:=
time
.
Now
()
.
Add
(
45
*
time
.
Minute
)
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
RateLimitResetAt
:
&
resetAt
,
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"safe-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
1
,
repo
.
updateCredentialsCalls
)
require
.
NotNil
(
t
,
repo
.
account
.
RateLimitResetAt
)
require
.
WithinDuration
(
t
,
resetAt
,
*
repo
.
account
.
RateLimitResetAt
,
time
.
Second
)
}
func
TestRefreshIfNeeded_LockHeld
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
...
...
@@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid_grant"
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock still released via defer
}
...
...
@@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result
:=
MergeCredentials
(
old
,
new
)
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
}
// ========== BuildClaudeAccountCredentials tests ==========
...
...
backend/internal/service/openai_account_scheduler.go
View file @
fa68cbad
...
...
@@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
}
account
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
req
.
RequestedModel
)
if
account
==
nil
{
_
=
s
.
service
.
deleteStickySessionAccountID
(
ctx
,
req
.
GroupID
,
sessionHash
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
...
@@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
}
fresh
=
s
.
service
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
req
.
RequestedModel
)
if
fresh
==
nil
||
!
s
.
isAccountTransportCompatible
(
fresh
,
req
.
RequiredTransport
)
{
continue
}
result
,
acquireErr
:=
s
.
service
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
if
acquireErr
!=
nil
{
return
nil
,
len
(
candidates
),
topK
,
loadSkew
,
acquireErr
...
...
backend/internal/service/openai_account_scheduler_test.go
View file @
fa68cbad
...
...
@@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require
.
Equal
(
t
,
int64
(
32002
),
account
.
ID
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10103
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleSticky
:=
&
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleBackup
:=
&
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbSticky
:=
Account
{
ID
:
33001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbBackup
:=
Account
{
ID
:
33002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
cache
:=
&
stubGatewayCache
{
sessionBindings
:
map
[
string
]
int64
{
"openai:session_hash_db_runtime_recheck"
:
33001
}}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
staleSticky
,
staleBackup
},
accountsByID
:
map
[
int64
]
*
Account
{
33001
:
staleSticky
,
33002
:
staleBackup
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbSticky
,
dbBackup
}},
cache
:
cache
,
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
}
selection
,
decision
,
err
:=
svc
.
SelectAccountWithScheduler
(
ctx
,
&
groupID
,
""
,
"session_hash_db_runtime_recheck"
,
"gpt-5.1"
,
nil
,
OpenAIUpstreamTransportAny
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
selection
)
require
.
NotNil
(
t
,
selection
.
Account
)
require
.
Equal
(
t
,
int64
(
33002
),
selection
.
Account
.
ID
)
require
.
Equal
(
t
,
openAIAccountScheduleLayerLoadBalance
,
decision
.
Layer
)
}
func
TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10104
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
stalePrimary
:=
&
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
}
staleSecondary
:=
&
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
dbPrimary
:=
Account
{
ID
:
34001
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
0
,
RateLimitResetAt
:
&
rateLimitedUntil
}
dbSecondary
:=
Account
{
ID
:
34002
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
5
}
snapshotCache
:=
&
openAISnapshotCacheStub
{
snapshotAccounts
:
[]
*
Account
{
stalePrimary
,
staleSecondary
},
accountsByID
:
map
[
int64
]
*
Account
{
34001
:
stalePrimary
,
34002
:
staleSecondary
},
}
snapshotService
:=
&
SchedulerSnapshotService
{
cache
:
snapshotCache
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbPrimary
,
dbSecondary
}},
cfg
:
&
config
.
Config
{},
schedulerSnapshot
:
snapshotService
,
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
int64
(
34002
),
account
.
ID
)
}
func
TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
9
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
fa68cbad
...
...
@@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
...
...
@@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if
fresh
==
nil
{
continue
}
fresh
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
fresh
,
requestedModel
)
if
fresh
==
nil
{
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
...
...
@@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
if
!
clearSticky
&&
account
.
IsSchedulable
()
&&
account
.
IsOpenAI
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
}
else
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
refreshStickySessionTTL
(
ctx
,
groupID
,
sessionHash
,
openaiStickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
accountID
,
MaxConcurrency
:
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
}
}
}
...
...
@@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return
fresh
}
func
(
s
*
OpenAIGatewayService
)
recheckSelectedOpenAIAccountFromDB
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
*
Account
{
if
account
==
nil
{
return
nil
}
if
s
.
schedulerSnapshot
==
nil
||
s
.
accountRepo
==
nil
{
return
account
}
latest
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latest
==
nil
{
return
nil
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
latest
,
time
.
Now
())
if
!
latest
.
IsSchedulable
()
||
!
latest
.
IsOpenAI
()
{
return
nil
}
if
requestedModel
!=
""
&&
!
latest
.
IsModelSupported
(
requestedModel
)
{
return
nil
}
return
latest
}
func
(
s
*
OpenAIGatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
var
(
account
*
Account
...
...
@@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
logOpenAIInstructionsRequiredDebug
(
ctx
,
c
,
account
,
resp
.
StatusCode
,
upstreamMsg
,
requestBody
,
body
)
if
s
.
rateLimitService
!=
nil
{
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_
=
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
...
...
backend/internal/service/openai_oauth_passthrough_test.go
View file @
fa68cbad
...
...
@@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require
.
True
(
t
,
arr
[
len
(
arr
)
-
1
]
.
Passthrough
)
}
func
TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/responses"
,
bytes
.
NewReader
(
nil
))
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.1.0"
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`
)
resetAt
:=
time
.
Now
()
.
Add
(
7
*
24
*
time
.
Hour
)
.
Unix
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"x-request-id"
:
[]
string
{
"rid-rate-limit"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
fmt
.
Sprintf
(
`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`
,
resetAt
))),
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
resp
}
repo
:=
&
openAIWSRateLimitSignalRepo
{}
rateSvc
:=
&
RateLimitService
{
accountRepo
:
repo
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
ForceCodexCLI
:
false
}},
httpUpstream
:
upstream
,
rateLimitService
:
rateSvc
,
}
account
:=
&
Account
{
ID
:
123
,
Name
:
"acc"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
},
Extra
:
map
[
string
]
any
{
"openai_passthrough"
:
true
},
Status
:
StatusActive
,
Schedulable
:
true
,
RateMultiplier
:
f64p
(
1
),
}
_
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
originalBody
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"usage_limit_reached"
)
require
.
Len
(
t
,
repo
.
rateLimitCalls
,
1
)
require
.
WithinDuration
(
t
,
time
.
Unix
(
resetAt
,
0
),
repo
.
rateLimitCalls
[
0
],
2
*
time
.
Second
)
}
func
TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/service/openai_oauth_service.go
View file @
fa68cbad
...
...
@@ -29,9 +29,10 @@ type soraSessionChunk struct {
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type
OpenAIOAuthService
struct
{
sessionStore
*
openai
.
SessionStore
proxyRepo
ProxyRepository
oauthClient
OpenAIOAuthClient
sessionStore
*
openai
.
SessionStore
proxyRepo
ProxyRepository
oauthClient
OpenAIOAuthClient
privacyClientFactory
PrivacyClientFactory
// 用于调用 chatgpt.com/backend-api(ImpersonateChrome)
}
// NewOpenAIOAuthService creates a new OpenAI OAuth service
...
...
@@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli
}
}
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。
func
(
s
*
OpenAIOAuthService
)
SetPrivacyClientFactory
(
factory
PrivacyClientFactory
)
{
s
.
privacyClientFactory
=
factory
}
// OpenAIAuthURLResult contains the authorization URL and session info
type
OpenAIAuthURLResult
struct
{
AuthURL
string
`json:"auth_url"`
...
...
@@ -131,6 +138,7 @@ type OpenAITokenInfo struct {
ChatGPTUserID
string
`json:"chatgpt_user_id,omitempty"`
OrganizationID
string
`json:"organization_id,omitempty"`
PlanType
string
`json:"plan_type,omitempty"`
PrivacyMode
string
`json:"privacy_mode,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
...
...
@@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo
.
PlanType
=
userInfo
.
PlanType
}
// id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全
if
tokenInfo
.
PlanType
==
""
&&
tokenInfo
.
AccessToken
!=
""
&&
s
.
privacyClientFactory
!=
nil
{
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号
orgID
:=
tokenInfo
.
OrganizationID
if
orgID
==
""
{
if
atClaims
,
err
:=
openai
.
DecodeIDToken
(
tokenInfo
.
AccessToken
);
err
==
nil
&&
atClaims
.
OpenAIAuth
!=
nil
{
orgID
=
atClaims
.
OpenAIAuth
.
POID
}
}
if
info
:=
fetchChatGPTAccountInfo
(
ctx
,
s
.
privacyClientFactory
,
tokenInfo
.
AccessToken
,
proxyURL
,
orgID
);
info
!=
nil
{
if
tokenInfo
.
PlanType
==
""
&&
info
.
PlanType
!=
""
{
tokenInfo
.
PlanType
=
info
.
PlanType
}
if
tokenInfo
.
Email
==
""
&&
info
.
Email
!=
""
{
tokenInfo
.
Email
=
info
.
Email
}
}
}
// 尝试设置隐私(关闭训练数据共享),best-effort
if
tokenInfo
.
AccessToken
!=
""
&&
s
.
privacyClientFactory
!=
nil
{
tokenInfo
.
PrivacyMode
=
disableOpenAITraining
(
ctx
,
s
.
privacyClientFactory
,
tokenInfo
.
AccessToken
,
proxyURL
)
}
return
tokenInfo
,
nil
}
...
...
backend/internal/service/openai_privacy_service.go
View file @
fa68cbad
...
...
@@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
return
PrivacyModeTrainingOff
}
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
type
ChatGPTAccountInfo
struct
{
PlanType
string
Email
string
}
const
chatGPTAccountsCheckURL
=
"https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
// Returns nil on any failure (best-effort, non-blocking).
func
fetchChatGPTAccountInfo
(
ctx
context
.
Context
,
clientFactory
PrivacyClientFactory
,
accessToken
,
proxyURL
,
orgID
string
)
*
ChatGPTAccountInfo
{
if
accessToken
==
""
||
clientFactory
==
nil
{
return
nil
}
ctx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
15
*
time
.
Second
)
defer
cancel
()
client
,
err
:=
clientFactory
(
proxyURL
)
if
err
!=
nil
{
slog
.
Debug
(
"chatgpt_account_check_client_error"
,
"error"
,
err
.
Error
())
return
nil
}
var
result
map
[
string
]
any
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Authorization"
,
"Bearer "
+
accessToken
)
.
SetHeader
(
"Origin"
,
"https://chatgpt.com"
)
.
SetHeader
(
"Referer"
,
"https://chatgpt.com/"
)
.
SetHeader
(
"Accept"
,
"application/json"
)
.
SetSuccessResult
(
&
result
)
.
Get
(
chatGPTAccountsCheckURL
)
if
err
!=
nil
{
slog
.
Debug
(
"chatgpt_account_check_request_error"
,
"error"
,
err
.
Error
())
return
nil
}
if
!
resp
.
IsSuccessState
()
{
slog
.
Debug
(
"chatgpt_account_check_failed"
,
"status"
,
resp
.
StatusCode
,
"body"
,
truncate
(
resp
.
String
(),
200
))
return
nil
}
info
:=
&
ChatGPTAccountInfo
{}
accounts
,
ok
:=
result
[
"accounts"
]
.
(
map
[
string
]
any
)
if
!
ok
{
slog
.
Debug
(
"chatgpt_account_check_no_accounts"
,
"body"
,
truncate
(
resp
.
String
(),
300
))
return
nil
}
// 优先匹配 orgID 对应的账号(access_token JWT 中的 poid)
if
orgID
!=
""
{
if
matched
:=
extractPlanFromAccount
(
accounts
,
orgID
);
matched
!=
""
{
info
.
PlanType
=
matched
}
}
// 未匹配到时,遍历所有账号:优先 is_default,次选非 free
if
info
.
PlanType
==
""
{
var
defaultPlan
,
paidPlan
,
anyPlan
string
for
_
,
acctRaw
:=
range
accounts
{
acct
,
ok
:=
acctRaw
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
planType
:=
extractPlanType
(
acct
)
if
planType
==
""
{
continue
}
if
anyPlan
==
""
{
anyPlan
=
planType
}
if
account
,
ok
:=
acct
[
"account"
]
.
(
map
[
string
]
any
);
ok
{
if
isDefault
,
_
:=
account
[
"is_default"
]
.
(
bool
);
isDefault
{
defaultPlan
=
planType
}
}
if
!
strings
.
EqualFold
(
planType
,
"free"
)
&&
paidPlan
==
""
{
paidPlan
=
planType
}
}
// 优先级:default > 非 free > 任意
switch
{
case
defaultPlan
!=
""
:
info
.
PlanType
=
defaultPlan
case
paidPlan
!=
""
:
info
.
PlanType
=
paidPlan
default
:
info
.
PlanType
=
anyPlan
}
}
if
info
.
PlanType
==
""
{
slog
.
Debug
(
"chatgpt_account_check_no_plan_type"
,
"body"
,
truncate
(
resp
.
String
(),
300
))
return
nil
}
slog
.
Info
(
"chatgpt_account_check_success"
,
"plan_type"
,
info
.
PlanType
,
"org_id"
,
orgID
)
return
info
}
// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type
func
extractPlanFromAccount
(
accounts
map
[
string
]
any
,
accountKey
string
)
string
{
acctRaw
,
ok
:=
accounts
[
accountKey
]
if
!
ok
{
return
""
}
acct
,
ok
:=
acctRaw
.
(
map
[
string
]
any
)
if
!
ok
{
return
""
}
return
extractPlanType
(
acct
)
}
// extractPlanType 从单个 account 对象中提取 plan_type
func
extractPlanType
(
acct
map
[
string
]
any
)
string
{
if
account
,
ok
:=
acct
[
"account"
]
.
(
map
[
string
]
any
);
ok
{
if
planType
,
ok
:=
account
[
"plan_type"
]
.
(
string
);
ok
&&
planType
!=
""
{
return
planType
}
}
if
entitlement
,
ok
:=
acct
[
"entitlement"
]
.
(
map
[
string
]
any
);
ok
{
if
subPlan
,
ok
:=
entitlement
[
"subscription_plan"
]
.
(
string
);
ok
&&
subPlan
!=
""
{
return
subPlan
}
}
return
""
}
func
truncate
(
s
string
,
n
int
)
string
{
if
len
(
s
)
<=
n
{
return
s
...
...
backend/internal/service/openai_ws_account_sticky_test.go
View file @
fa68cbad
...
...
@@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require
.
Zero
(
t
,
boundAccountID
)
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
24
)
rateLimitedUntil
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
staleAccount
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
dbAccount
:=
Account
{
ID
:
13
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
RateLimitResetAt
:
&
rateLimitedUntil
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
cache
:=
&
stubGatewayCache
{}
store
:=
NewOpenAIWSStateStore
(
cache
)
cfg
:=
newOpenAIWSV2TestConfig
()
snapshotCache
:=
&
openAISnapshotCacheStub
{
accountsByID
:
map
[
int64
]
*
Account
{
dbAccount
.
ID
:
staleAccount
},
}
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
dbAccount
}},
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
stubConcurrencyCache
{}),
openaiWSStateStore
:
store
,
schedulerSnapshot
:
&
SchedulerSnapshotService
{
cache
:
snapshotCache
},
}
require
.
NoError
(
t
,
store
.
BindResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
,
dbAccount
.
ID
,
time
.
Hour
))
selection
,
err
:=
svc
.
SelectAccountByPreviousResponseID
(
ctx
,
&
groupID
,
"resp_prev_db_rl"
,
"gpt-5.1"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
selection
,
"DB 中已限流的账号不应继续命中 previous_response_id 粘连"
)
boundAccountID
,
getErr
:=
store
.
GetResponseAccount
(
ctx
,
groupID
,
"resp_prev_db_rl"
)
require
.
NoError
(
t
,
getErr
)
require
.
Zero
(
t
,
boundAccountID
)
}
func
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
23
)
...
...
backend/internal/service/openai_ws_forwarder.go
View file @
fa68cbad
...
...
@@ -3840,6 +3840,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if
requestedModel
!=
""
&&
!
account
.
IsModelSupported
(
requestedModel
)
{
return
nil
,
nil
}
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
_
=
store
.
DeleteResponseAccount
(
ctx
,
derefGroupID
(
groupID
),
responseID
)
return
nil
,
nil
}
result
,
acquireErr
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
acquireErr
==
nil
&&
result
.
Acquired
{
...
...
backend/internal/service/openai_ws_ratelimit_signal_test.go
View file @
fa68cbad
...
...
@@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
return
nil
}
func
(
r
*
openAICodexExtraListRepo
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
openAICodexExtraListRepo
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
_
=
platform
_
=
accountType
_
=
status
_
=
search
_
=
groupID
_
=
privacyMode
return
r
.
accounts
,
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
r
.
accounts
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
},
nil
}
...
...
@@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformOpenAI
,
AccountTypeOAuth
,
""
,
""
,
0
)
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformOpenAI
,
AccountTypeOAuth
,
""
,
""
,
0
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Len
(
t
,
accounts
,
1
)
...
...
backend/internal/service/ops_concurrency.go
View file @
fa68cbad
...
...
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts
,
pageInfo
,
err
:=
s
.
accountRepo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
opsAccountsPageSize
,
},
platformFilter
,
""
,
""
,
""
,
0
)
},
platformFilter
,
""
,
""
,
""
,
0
,
""
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
backend/internal/service/ops_models.go
View file @
fa68cbad
...
...
@@ -62,6 +62,12 @@ type OpsErrorLog struct {
ClientIP
*
string
`json:"client_ip"`
RequestPath
string
`json:"request_path"`
Stream
bool
`json:"stream"`
InboundEndpoint
string
`json:"inbound_endpoint"`
UpstreamEndpoint
string
`json:"upstream_endpoint"`
RequestedModel
string
`json:"requested_model"`
UpstreamModel
string
`json:"upstream_model"`
RequestType
*
int16
`json:"request_type"`
}
type
OpsErrorLogDetail
struct
{
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment