Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d367d1cd
Commit
d367d1cd
authored
Feb 09, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test-sora
parents
d7011163
3c46f7d2
Changes
104
Show whitespace changes
Inline
Side-by-side
backend/internal/service/anthropic_session.go
0 → 100644
View file @
d367d1cd
package
service
import
(
"encoding/json"
"strings"
"time"
)
// Constants for the Anthropic session fallback.
const (
	// anthropicSessionTTLSeconds is the Anthropic session cache TTL in seconds (5 minutes).
	anthropicSessionTTLSeconds = 300

	// anthropicDigestSessionKeyPrefix is the key prefix used for Anthropic
	// digest-fallback session keys.
	anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)

// AnthropicSessionTTL returns the Anthropic session cache TTL as a time.Duration.
func AnthropicSessionTTL() time.Duration {
	const ttl = anthropicSessionTTLSeconds * time.Second
	return ttl
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func
BuildAnthropicDigestChain
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
return
""
}
var
parts
[]
string
// 1. system prompt
if
parsed
.
System
!=
nil
{
systemData
,
_
:=
json
.
Marshal
(
parsed
.
System
)
if
len
(
systemData
)
>
0
&&
string
(
systemData
)
!=
"null"
{
parts
=
append
(
parts
,
"s:"
+
shortHash
(
systemData
))
}
}
// 2. messages
for
_
,
msg
:=
range
parsed
.
Messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
role
,
_
:=
msgMap
[
"role"
]
.
(
string
)
prefix
:=
rolePrefix
(
role
)
content
:=
msgMap
[
"content"
]
contentData
,
_
:=
json
.
Marshal
(
content
)
parts
=
append
(
parts
,
prefix
+
":"
+
shortHash
(
contentData
))
}
return
strings
.
Join
(
parts
,
"-"
)
}
// rolePrefix maps an Anthropic message role to its single-character chain tag:
// "assistant" becomes "a"; every other role (including the empty string)
// becomes "u".
func rolePrefix(role string) string {
	if role == "assistant" {
		return "a"
	}
	return "u"
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func
GenerateAnthropicDigestSessionKey
(
prefixHash
,
uuid
string
)
string
{
prefix
:=
prefixHash
if
len
(
prefixHash
)
>=
8
{
prefix
=
prefixHash
[
:
8
]
}
uuidPart
:=
uuid
if
len
(
uuid
)
>=
8
{
uuidPart
=
uuid
[
:
8
]
}
return
anthropicDigestSessionKeyPrefix
+
prefix
+
":"
+
uuidPart
}
backend/internal/service/anthropic_session_test.go
0 → 100644
View file @
d367d1cd
package
service
import
(
"strings"
"testing"
)
func
TestBuildAnthropicDigestChain_NilRequest
(
t
*
testing
.
T
)
{
result
:=
BuildAnthropicDigestChain
(
nil
)
if
result
!=
""
{
t
.
Errorf
(
"expected empty string for nil request, got: %s"
,
result
)
}
}
func
TestBuildAnthropicDigestChain_EmptyMessages
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
Messages
:
[]
any
{},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
if
result
!=
""
{
t
.
Errorf
(
"expected empty string for empty messages, got: %s"
,
result
)
}
}
func
TestBuildAnthropicDigestChain_SingleUserMessage
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
parts
:=
splitChain
(
result
)
if
len
(
parts
)
!=
1
{
t
.
Fatalf
(
"expected 1 part, got %d: %s"
,
len
(
parts
),
result
)
}
if
!
strings
.
HasPrefix
(
parts
[
0
],
"u:"
)
{
t
.
Errorf
(
"expected prefix 'u:', got: %s"
,
parts
[
0
])
}
}
func
TestBuildAnthropicDigestChain_UserAndAssistant
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"hi there"
},
},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
parts
:=
splitChain
(
result
)
if
len
(
parts
)
!=
2
{
t
.
Fatalf
(
"expected 2 parts, got %d: %s"
,
len
(
parts
),
result
)
}
if
!
strings
.
HasPrefix
(
parts
[
0
],
"u:"
)
{
t
.
Errorf
(
"part[0] expected prefix 'u:', got: %s"
,
parts
[
0
])
}
if
!
strings
.
HasPrefix
(
parts
[
1
],
"a:"
)
{
t
.
Errorf
(
"part[1] expected prefix 'a:', got: %s"
,
parts
[
1
])
}
}
func
TestBuildAnthropicDigestChain_WithSystemString
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
System
:
"You are a helpful assistant"
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
parts
:=
splitChain
(
result
)
if
len
(
parts
)
!=
2
{
t
.
Fatalf
(
"expected 2 parts (s + u), got %d: %s"
,
len
(
parts
),
result
)
}
if
!
strings
.
HasPrefix
(
parts
[
0
],
"s:"
)
{
t
.
Errorf
(
"part[0] expected prefix 's:', got: %s"
,
parts
[
0
])
}
if
!
strings
.
HasPrefix
(
parts
[
1
],
"u:"
)
{
t
.
Errorf
(
"part[1] expected prefix 'u:', got: %s"
,
parts
[
1
])
}
}
func
TestBuildAnthropicDigestChain_WithSystemContentBlocks
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
System
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"You are a helpful assistant"
},
},
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
parts
:=
splitChain
(
result
)
if
len
(
parts
)
!=
2
{
t
.
Fatalf
(
"expected 2 parts (s + u), got %d: %s"
,
len
(
parts
),
result
)
}
if
!
strings
.
HasPrefix
(
parts
[
0
],
"s:"
)
{
t
.
Errorf
(
"part[0] expected prefix 's:', got: %s"
,
parts
[
0
])
}
}
func
TestBuildAnthropicDigestChain_ConversationPrefixRelationship
(
t
*
testing
.
T
)
{
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system
:=
"You are a helpful assistant"
// 第 1 轮: system + user
round1
:=
&
ParsedRequest
{
System
:
system
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
chain1
:=
BuildAnthropicDigestChain
(
round1
)
// 第 2 轮: system + user + assistant + user
round2
:=
&
ParsedRequest
{
System
:
system
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"hi there"
},
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"how are you?"
},
},
}
chain2
:=
BuildAnthropicDigestChain
(
round2
)
// 第 3 轮: system + user + assistant + user + assistant + user
round3
:=
&
ParsedRequest
{
System
:
system
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"hi there"
},
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"how are you?"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"I'm doing well"
},
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"great"
},
},
}
chain3
:=
BuildAnthropicDigestChain
(
round3
)
t
.
Logf
(
"Chain1: %s"
,
chain1
)
t
.
Logf
(
"Chain2: %s"
,
chain2
)
t
.
Logf
(
"Chain3: %s"
,
chain3
)
// chain1 是 chain2 的前缀
if
!
strings
.
HasPrefix
(
chain2
,
chain1
)
{
t
.
Errorf
(
"chain1 should be prefix of chain2:
\n
chain1: %s
\n
chain2: %s"
,
chain1
,
chain2
)
}
// chain2 是 chain3 的前缀
if
!
strings
.
HasPrefix
(
chain3
,
chain2
)
{
t
.
Errorf
(
"chain2 should be prefix of chain3:
\n
chain2: %s
\n
chain3: %s"
,
chain2
,
chain3
)
}
// chain1 也是 chain3 的前缀(传递性)
if
!
strings
.
HasPrefix
(
chain3
,
chain1
)
{
t
.
Errorf
(
"chain1 should be prefix of chain3:
\n
chain1: %s
\n
chain3: %s"
,
chain1
,
chain3
)
}
}
func
TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain
(
t
*
testing
.
T
)
{
parsed1
:=
&
ParsedRequest
{
System
:
"System A"
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
parsed2
:=
&
ParsedRequest
{
System
:
"System B"
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
}
chain1
:=
BuildAnthropicDigestChain
(
parsed1
)
chain2
:=
BuildAnthropicDigestChain
(
parsed2
)
if
chain1
==
chain2
{
t
.
Error
(
"Different system prompts should produce different chains"
)
}
// 但 user 部分的 hash 应该相同
parts1
:=
splitChain
(
chain1
)
parts2
:=
splitChain
(
chain2
)
if
parts1
[
1
]
!=
parts2
[
1
]
{
t
.
Error
(
"Same user message should produce same hash regardless of system"
)
}
}
func
TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain
(
t
*
testing
.
T
)
{
parsed1
:=
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"ORIGINAL reply"
},
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"next"
},
},
}
parsed2
:=
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"TAMPERED reply"
},
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"next"
},
},
}
chain1
:=
BuildAnthropicDigestChain
(
parsed1
)
chain2
:=
BuildAnthropicDigestChain
(
parsed2
)
if
chain1
==
chain2
{
t
.
Error
(
"Different content should produce different chains"
)
}
parts1
:=
splitChain
(
chain1
)
parts2
:=
splitChain
(
chain2
)
// 第一个 user message hash 应该相同
if
parts1
[
0
]
!=
parts2
[
0
]
{
t
.
Error
(
"First user message hash should be the same"
)
}
// assistant reply hash 应该不同
if
parts1
[
1
]
==
parts2
[
1
]
{
t
.
Error
(
"Assistant reply hash should differ"
)
}
}
func
TestBuildAnthropicDigestChain_Deterministic
(
t
*
testing
.
T
)
{
parsed
:=
&
ParsedRequest
{
System
:
"test system"
,
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hello"
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
"hi"
},
},
}
chain1
:=
BuildAnthropicDigestChain
(
parsed
)
chain2
:=
BuildAnthropicDigestChain
(
parsed
)
if
chain1
!=
chain2
{
t
.
Errorf
(
"BuildAnthropicDigestChain not deterministic: %s vs %s"
,
chain1
,
chain2
)
}
}
func
TestGenerateAnthropicDigestSessionKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
prefixHash
string
uuid
string
want
string
}{
{
name
:
"normal 16 char hash with uuid"
,
prefixHash
:
"abcdefgh12345678"
,
uuid
:
"550e8400-e29b-41d4-a716-446655440000"
,
want
:
"anthropic:digest:abcdefgh:550e8400"
,
},
{
name
:
"exactly 8 chars"
,
prefixHash
:
"12345678"
,
uuid
:
"abcdefgh"
,
want
:
"anthropic:digest:12345678:abcdefgh"
,
},
{
name
:
"short values"
,
prefixHash
:
"abc"
,
uuid
:
"xyz"
,
want
:
"anthropic:digest:abc:xyz"
,
},
{
name
:
"empty values"
,
prefixHash
:
""
,
uuid
:
""
,
want
:
"anthropic:digest::"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
GenerateAnthropicDigestSessionKey
(
tt
.
prefixHash
,
tt
.
uuid
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q"
,
tt
.
prefixHash
,
tt
.
uuid
,
got
,
tt
.
want
)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t
.
Run
(
"different uuid different key"
,
func
(
t
*
testing
.
T
)
{
hash
:=
"sameprefix123456"
result1
:=
GenerateAnthropicDigestSessionKey
(
hash
,
"uuid0001-session-a"
)
result2
:=
GenerateAnthropicDigestSessionKey
(
hash
,
"uuid0002-session-b"
)
if
result1
==
result2
{
t
.
Errorf
(
"Different UUIDs should produce different session keys: %s vs %s"
,
result1
,
result2
)
}
})
}
func
TestAnthropicSessionTTL
(
t
*
testing
.
T
)
{
ttl
:=
AnthropicSessionTTL
()
if
ttl
.
Seconds
()
!=
300
{
t
.
Errorf
(
"expected 300 seconds, got: %v"
,
ttl
.
Seconds
())
}
}
func
TestBuildAnthropicDigestChain_ContentBlocks
(
t
*
testing
.
T
)
{
// 测试 content 为 content blocks 数组的情况
parsed
:=
&
ParsedRequest
{
Messages
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"describe this image"
},
map
[
string
]
any
{
"type"
:
"image"
,
"source"
:
map
[
string
]
any
{
"type"
:
"base64"
}},
},
},
},
}
result
:=
BuildAnthropicDigestChain
(
parsed
)
parts
:=
splitChain
(
result
)
if
len
(
parts
)
!=
1
{
t
.
Fatalf
(
"expected 1 part, got %d: %s"
,
len
(
parts
),
result
)
}
if
!
strings
.
HasPrefix
(
parts
[
0
],
"u:"
)
{
t
.
Errorf
(
"expected prefix 'u:', got: %s"
,
parts
[
0
])
}
}
backend/internal/service/antigravity_gateway_service.go
View file @
d367d1cd
...
...
@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
mathrand
"math/rand"
"net"
"net/http"
...
...
@@ -35,7 +36,7 @@ const (
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
antigravityRateLimitThreshold
=
7
*
time
.
Second
antigravitySmartRetryMinWait
=
1
*
time
.
Second
// 智能重试最小等待时间
antigravitySmartRetryMaxAttempts
=
3
// 智能重试最大次数
antigravitySmartRetryMaxAttempts
=
1
// 智能重试最大次数
(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration
=
30
*
time
.
Second
// 默认限流时间(无 retryDelay 时使用)
// Google RPC 状态和类型常量
...
...
@@ -100,12 +101,11 @@ type antigravityRetryLoopParams struct {
accessToken
string
action
string
body
[]
byte
quotaScope
AntigravityQuotaScope
c
*
gin
.
Context
httpUpstream
HTTPUpstream
settingService
*
SettingService
accountRepo
AccountRepository
// 用于智能重试的模型级别限流
handleError
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
handleError
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
requestedModel
string
// 用于限流检查的原始请求模型
isStickySession
bool
// 是否为粘性会话(用于账号切换时的缓存计费判断)
groupID
int64
// 用于模型级限流时清除粘性会话
...
...
@@ -148,13 +148,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// 情况1: retryDelay >= 阈值,限流模型并切换账号
if
shouldRateLimitModel
{
log
.
Printf
(
"%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)"
,
p
.
prefix
,
resp
.
StatusCode
,
modelName
,
p
.
account
.
ID
)
rateLimitDuration
:=
waitDuration
if
rateLimitDuration
<=
0
{
rateLimitDuration
=
antigravityDefaultRateLimitDuration
}
log
.
Printf
(
"%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)"
,
p
.
prefix
,
resp
.
StatusCode
,
modelName
,
p
.
account
.
ID
,
rateLimitDuration
,
truncateForLog
(
respBody
,
200
))
resetAt
:=
time
.
Now
()
.
Add
(
antigravityDefaultR
ateLimitDuration
)
resetAt
:=
time
.
Now
()
.
Add
(
r
ateLimitDuration
)
if
!
setModelRateLimitByModelName
(
p
.
ctx
,
p
.
accountRepo
,
p
.
account
.
ID
,
modelName
,
p
.
prefix
,
resp
.
StatusCode
,
resetAt
,
false
)
{
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
quotaScope
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
log
.
Printf
(
"%s status=%d rate_limited account=%d (no
scope
mapping)"
,
p
.
prefix
,
resp
.
StatusCode
,
p
.
account
.
ID
)
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
requestedModel
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
log
.
Printf
(
"%s status=%d rate_limited account=%d (no
model
mapping)"
,
p
.
prefix
,
resp
.
StatusCode
,
p
.
account
.
ID
)
}
else
{
s
.
updateAccountModelRateLimitInCache
(
p
.
ctx
,
p
.
account
,
modelName
,
resetAt
)
}
...
...
@@ -190,7 +194,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryReq
,
err
:=
antigravity
.
NewAPIRequestWithURL
(
p
.
ctx
,
baseURL
,
p
.
action
,
p
.
accessToken
,
p
.
body
)
if
err
!=
nil
{
log
.
Printf
(
"%s status=smart_retry_request_build_failed error=%v"
,
p
.
prefix
,
err
)
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
quotaScope
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
requestedModel
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
return
&
smartRetryResult
{
action
:
smartRetryActionBreakWithResp
,
resp
:
&
http
.
Response
{
...
...
@@ -233,20 +237,33 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
// 所有重试都失败,限流当前模型并切换账号
log
.
Printf
(
"%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)"
,
p
.
prefix
,
resp
.
StatusCode
,
antigravitySmartRetryMaxAttempts
,
modelName
,
p
.
account
.
ID
)
rateLimitDuration
:=
waitDuration
if
rateLimitDuration
<=
0
{
rateLimitDuration
=
antigravityDefaultRateLimitDuration
}
retryBody
:=
lastRetryBody
if
retryBody
==
nil
{
retryBody
=
respBody
}
log
.
Printf
(
"%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)"
,
p
.
prefix
,
resp
.
StatusCode
,
antigravitySmartRetryMaxAttempts
,
modelName
,
p
.
account
.
ID
,
rateLimitDuration
,
truncateForLog
(
retryBody
,
200
))
resetAt
:=
time
.
Now
()
.
Add
(
antigravityDefaultR
ateLimitDuration
)
resetAt
:=
time
.
Now
()
.
Add
(
r
ateLimitDuration
)
if
p
.
accountRepo
!=
nil
&&
modelName
!=
""
{
if
err
:=
p
.
accountRepo
.
SetModelRateLimit
(
p
.
ctx
,
p
.
account
.
ID
,
modelName
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"%s status=%d model_rate_limit_failed model=%s error=%v"
,
p
.
prefix
,
resp
.
StatusCode
,
modelName
,
err
)
}
else
{
log
.
Printf
(
"%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v"
,
p
.
prefix
,
resp
.
StatusCode
,
modelName
,
p
.
account
.
ID
,
antigravityDefaultR
ateLimitDuration
)
p
.
prefix
,
resp
.
StatusCode
,
modelName
,
p
.
account
.
ID
,
r
ateLimitDuration
)
s
.
updateAccountModelRateLimitInCache
(
p
.
ctx
,
p
.
account
,
modelName
,
resetAt
)
}
}
// 清除粘性会话绑定,避免下次请求仍命中限流账号
if
s
.
cache
!=
nil
&&
p
.
sessionHash
!=
""
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
p
.
ctx
,
p
.
groupID
,
p
.
sessionHash
)
}
// 返回账号切换信号,让上层切换账号重试
return
&
smartRetryResult
{
action
:
smartRetryActionBreakWithResp
,
...
...
@@ -264,22 +281,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func
(
s
*
AntigravityGatewayService
)
antigravityRetryLoop
(
p
antigravityRetryLoopParams
)
(
*
antigravityRetryLoopResult
,
error
)
{
// 预检查:如果账号已限流,
根据剩余时间决定等待或切换
// 预检查:如果账号已限流,
直接返回切换信号
if
p
.
requestedModel
!=
""
{
if
remaining
:=
p
.
account
.
GetRateLimitRemainingTimeWithContext
(
p
.
ctx
,
p
.
requestedModel
);
remaining
>
0
{
if
remaining
<
antigravityRateLimitThreshold
{
// 限流剩余时间较短,等待后继续
log
.
Printf
(
"%s pre_check: rate_limit_wait remaining=%v model=%s account=%d"
,
p
.
prefix
,
remaining
.
Truncate
(
time
.
Millisecond
),
p
.
requestedModel
,
p
.
account
.
ID
)
select
{
case
<-
p
.
ctx
.
Done
()
:
return
nil
,
p
.
ctx
.
Err
()
case
<-
time
.
After
(
remaining
)
:
}
}
else
{
// 限流剩余时间较长,返回账号切换信号
log
.
Printf
(
"%s pre_check: rate_limit_switch remaining=%v model=%s account=%d"
,
p
.
prefix
,
remaining
.
Truncate
(
time
.
S
econd
),
p
.
requestedModel
,
p
.
account
.
ID
)
p
.
prefix
,
remaining
.
Truncate
(
time
.
Millis
econd
),
p
.
requestedModel
,
p
.
account
.
ID
)
return
nil
,
&
AntigravityAccountSwitchError
{
OriginalAccountID
:
p
.
account
.
ID
,
RateLimitedModel
:
p
.
requestedModel
,
...
...
@@ -287,7 +293,6 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
}
}
}
}
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
if
len
(
availableURLs
)
==
0
{
...
...
@@ -360,11 +365,26 @@ urlFallbackLoop:
return
nil
,
fmt
.
Errorf
(
"upstream request failed after retries: %w"
,
err
)
}
//
429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if
resp
.
StatusCode
=
=
http
.
StatusTooManyRequests
||
resp
.
StatusCode
==
http
.
StatusServiceUnavailable
{
//
统一处理错误响应
if
resp
.
StatusCode
>
=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
// ★ 统一入口:自定义错误码 + 临时不可调度
if
handled
,
policyErr
:=
s
.
applyErrorPolicy
(
p
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
);
handled
{
if
policyErr
!=
nil
{
return
nil
,
policyErr
}
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
urlFallbackLoop
}
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
||
resp
.
StatusCode
==
http
.
StatusServiceUnavailable
{
// 尝试智能重试处理(OAuth 账号专用)
smartResult
:=
s
.
handleSmartRetry
(
p
,
resp
,
respBody
,
baseURL
,
urlIdx
,
availableURLs
)
switch
smartResult
.
action
{
...
...
@@ -406,7 +426,7 @@ urlFallbackLoop:
}
// 重试用尽,标记账户限流
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
quotaScope
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
p
.
requestedModel
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
log
.
Printf
(
"%s status=%d rate_limited base_url=%s body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
baseURL
,
truncateForLog
(
respBody
,
200
))
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
...
...
@@ -416,11 +436,8 @@ urlFallbackLoop:
break
urlFallbackLoop
}
// 其他可重试错误(不包括 429 和 503,因为上面已处理)
if
resp
.
StatusCode
>=
400
&&
shouldRetryAntigravityError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
// 其他可重试错误(500/502/504/529,不包括 429 和 503)
if
shouldRetryAntigravityError
(
resp
.
StatusCode
)
{
if
attempt
<
antigravityMaxRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
...
...
@@ -441,6 +458,9 @@ urlFallbackLoop:
}
continue
}
}
// 其他 4xx 错误或重试用尽,直接返回
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
Header
:
resp
.
Header
.
Clone
(),
...
...
@@ -449,6 +469,7 @@ urlFallbackLoop:
break
urlFallbackLoop
}
// 成功响应(< 400)
break
urlFallbackLoop
}
}
...
...
@@ -581,6 +602,31 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
return
truncateString
(
string
(
body
),
maxBytes
)
}
// checkErrorPolicy nil 安全的包装
func
(
s
*
AntigravityGatewayService
)
checkErrorPolicy
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
body
[]
byte
)
ErrorPolicyResult
{
if
s
.
rateLimitService
==
nil
{
return
ErrorPolicyNone
}
return
s
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
statusCode
,
body
)
}
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
func
(
s
*
AntigravityGatewayService
)
applyErrorPolicy
(
p
antigravityRetryLoopParams
,
statusCode
int
,
headers
http
.
Header
,
respBody
[]
byte
)
(
handled
bool
,
retErr
error
)
{
switch
s
.
checkErrorPolicy
(
p
.
ctx
,
p
.
account
,
statusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
return
true
,
nil
case
ErrorPolicyMatched
:
_
=
p
.
handleError
(
p
.
ctx
,
p
.
prefix
,
p
.
account
,
statusCode
,
headers
,
respBody
,
p
.
requestedModel
,
p
.
groupID
,
p
.
sessionHash
,
p
.
isStickySession
)
return
true
,
nil
case
ErrorPolicyTempUnscheduled
:
slog
.
Info
(
"temp_unschedulable_matched"
,
"prefix"
,
p
.
prefix
,
"status_code"
,
statusCode
,
"account_id"
,
p
.
account
.
ID
)
return
true
,
&
AntigravityAccountSwitchError
{
OriginalAccountID
:
p
.
account
.
ID
,
IsStickySession
:
p
.
isStickySession
}
}
return
false
,
nil
}
// mapAntigravityModel 获取映射后的模型名
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping)
// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号
...
...
@@ -650,6 +696,7 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func
(
s
*
AntigravityGatewayService
)
TestConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
// 获取 token
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
...
...
@@ -964,8 +1011,24 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
}
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
isStickySession
bool
)
(
*
ForwardResult
,
error
)
{
// 上游透传账号直接转发,不走 OAuth token 刷新
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
ForwardUpstream
(
ctx
,
c
,
account
,
body
)
}
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
...
...
@@ -983,11 +1046,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if
mappedModel
==
""
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusForbidden
,
"permission_error"
,
fmt
.
Sprintf
(
"model %s not in whitelist"
,
claudeReq
.
Model
))
}
loadModel
:=
mappedModel
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
mappedModel
=
applyThinkingModelSuffix
(
mappedModel
,
thinkingEnabled
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -1022,11 +1083,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action
:=
"streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if
s
.
cache
!=
nil
{
_
,
_
=
s
.
cache
.
IncrModelCallCount
(
ctx
,
account
.
ID
,
loadModel
)
}
// 执行带重试的请求
result
,
err
:=
s
.
antigravityRetryLoop
(
antigravityRetryLoopParams
{
ctx
:
ctx
,
...
...
@@ -1036,7 +1092,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
accessToken
:
accessToken
,
action
:
action
,
body
:
geminiBody
,
quotaScope
:
quotaScope
,
c
:
c
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
...
...
@@ -1117,7 +1172,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
accessToken
:
accessToken
,
action
:
action
,
body
:
retryGeminiBody
,
quotaScope
:
quotaScope
,
c
:
c
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
...
...
@@ -1228,7 +1282,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
,
0
,
""
,
isStickySession
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
originalModel
,
0
,
""
,
isStickySession
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
...
...
@@ -1258,6 +1312,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
claudeReq
.
Stream
{
// 客户端要求流式,直接透传转换
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
...
...
@@ -1267,6 +1322,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
}
else
{
// 客户端要求非流式,收集流式响应后转换返回
streamRes
,
err
:=
s
.
handleClaudeStreamToNonStreaming
(
c
,
resp
,
startTime
,
originalModel
)
...
...
@@ -1285,6 +1341,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
},
nil
}
...
...
@@ -1582,211 +1639,20 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return
changed
,
nil
}
// ForwardUpstream passes a request straight through to an upstream Antigravity
// service. It is used for upstream-type accounts and forwards with the
// account's base_url + api_key credentials instead of an OAuth token.
//
// It returns an error when the credentials are missing, the body is not a
// valid Claude request, the model is blank, or the upstream request/response
// read fails. An upstream status >= 400 is NOT returned as an error: the
// upstream body is relayed to the client and a ForwardResult carrying only
// the model is returned.
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
	startTime := time.Now()
	sessionID := getSessionID(c)
	prefix := logPrefix(sessionID, account.Name)

	// Read the upstream configuration from the account credentials.
	baseURL := strings.TrimSpace(account.GetCredential("base_url"))
	apiKey := strings.TrimSpace(account.GetCredential("api_key"))
	if baseURL == "" || apiKey == "" {
		return nil, fmt.Errorf("upstream account missing base_url or api_key")
	}
	// Drop one trailing slash so the path can be appended below.
	baseURL = strings.TrimSuffix(baseURL, "/")

	// Parse the request to obtain the model information.
	var claudeReq antigravity.ClaudeRequest
	if err := json.Unmarshal(body, &claudeReq); err != nil {
		return nil, fmt.Errorf("parse claude request: %w", err)
	}
	if strings.TrimSpace(claudeReq.Model) == "" {
		return nil, fmt.Errorf("missing model")
	}
	originalModel := claudeReq.Model
	billingModel := originalModel

	// Build the upstream request URL.
	upstreamURL := baseURL + "/v1/messages"

	// Create the request; the original body is forwarded unmodified.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create upstream request: %w", err)
	}

	// Set request headers. The key is sent both as a bearer token and as x-api-key.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("x-api-key", apiKey) // Claude API compatibility

	// Pass through Claude-related headers from the client request.
	if v := c.GetHeader("anthropic-version"); v != "" {
		req.Header.Set("anthropic-version", v)
	}
	if v := c.GetHeader("anthropic-beta"); v != "" {
		req.Header.Set("anthropic-beta", v)
	}

	// Proxy URL (empty when the account has no proxy attached).
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	// Send the request.
	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		log.Printf("%s upstream request failed: %v", prefix, err)
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Handle error responses.
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body; a read error is ignored and
		// whatever was read is relayed.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))

		// On 429, mark the account as rate limited.
		if resp.StatusCode == http.StatusTooManyRequests {
			s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude, 0, "", false)
		}

		// Relay the upstream error to the client as-is.
		c.Header("Content-Type", resp.Header.Get("Content-Type"))
		c.Status(resp.StatusCode)
		_, _ = c.Writer.Write(respBody)

		return &ForwardResult{
			Model: billingModel,
		}, nil
	}

	// Handle the success response (streaming / non-streaming).
	var usage *ClaudeUsage
	var firstTokenMs *int

	if claudeReq.Stream {
		// Streaming response: pass through.
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("X-Accel-Buffering", "no")
		c.Status(http.StatusOK)

		usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
	} else {
		// Non-streaming response: pass through directly.
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("read upstream response: %w", err)
		}

		// Extract usage.
		usage = s.extractClaudeUsage(respBody)

		// NOTE(review): the client always receives 200 here, whatever 2xx/3xx
		// status the upstream sent — confirm this is intended.
		c.Header("Content-Type", resp.Header.Get("Content-Type"))
		c.Status(http.StatusOK)
		_, _ = c.Writer.Write(respBody)
	}

	// Build the billing result.
	duration := time.Since(startTime)
	log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())

	return &ForwardResult{
		Model:        billingModel,
		Stream:       claudeReq.Stream,
		Duration:     duration,
		FirstTokenMs: firstTokenMs,
		Usage: ClaudeUsage{
			InputTokens:              usage.InputTokens,
			OutputTokens:             usage.OutputTokens,
			CacheReadInputTokens:     usage.CacheReadInputTokens,
			CacheCreationInputTokens: usage.CacheCreationInputTokens,
		},
	}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
ClaudeUsage
,
*
int
)
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenRecorded
bool
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
buf
:=
make
([]
byte
,
0
,
64
*
1024
)
scanner
.
Buffer
(
buf
,
1024
*
1024
)
for
scanner
.
Scan
()
{
line
:=
scanner
.
Bytes
()
// 记录首 token 时间
if
!
firstTokenRecorded
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenRecorded
=
true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if
bytes
.
HasPrefix
(
line
,
[]
byte
(
"data: "
))
{
dataStr
:=
bytes
.
TrimPrefix
(
line
,
[]
byte
(
"data: "
))
var
event
map
[
string
]
any
if
json
.
Unmarshal
(
dataStr
,
&
event
)
==
nil
{
if
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
}
}
// 透传行
_
,
_
=
c
.
Writer
.
Write
(
line
)
_
,
_
=
c
.
Writer
.
Write
([]
byte
(
"
\n
"
))
c
.
Writer
.
Flush
()
}
return
usage
,
firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
// ForwardGemini 转发 Gemini 协议请求
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
,
isStickySession
bool
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
...
...
@@ -1799,7 +1665,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
len
(
body
)
==
0
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 解析请求以获取 image_size(用于图片计费)
imageSize
:=
s
.
extractImageSize
(
body
)
...
...
@@ -1869,11 +1734,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction
:=
"streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if
s
.
cache
!=
nil
{
_
,
_
=
s
.
cache
.
IncrModelCallCount
(
ctx
,
account
.
ID
,
mappedModel
)
}
// 执行带重试的请求
result
,
err
:=
s
.
antigravityRetryLoop
(
antigravityRetryLoopParams
{
ctx
:
ctx
,
...
...
@@ -1883,7 +1743,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
accessToken
:
accessToken
,
action
:
upstreamAction
,
body
:
wrappedBody
,
quotaScope
:
quotaScope
,
c
:
c
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
...
...
@@ -1957,7 +1816,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
unwrapErr
!=
nil
||
len
(
unwrappedForOps
)
==
0
{
unwrappedForOps
=
respBody
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
,
0
,
""
,
isStickySession
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
originalModel
,
0
,
""
,
isStickySession
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
unwrappedForOps
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
s
.
getUpstreamErrorDetail
(
unwrappedForOps
)
...
...
@@ -2004,6 +1863,7 @@ handleSuccess:
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
stream
{
// 客户端要求流式,直接透传
...
...
@@ -2014,6 +1874,7 @@ handleSuccess:
}
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
}
else
{
// 客户端要求非流式,收集流式响应后返回
streamRes
,
err
:=
s
.
handleGeminiStreamToNonStreaming
(
c
,
resp
,
startTime
)
...
...
@@ -2043,6 +1904,7 @@ handleSuccess:
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
...
...
@@ -2253,9 +2115,9 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
}
// retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认
5 分钟
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认
30s
if
info
.
RetryDelay
>=
antigravityRateLimitThreshold
{
return
false
,
true
,
0
,
info
.
ModelName
return
false
,
true
,
info
.
RetryDelay
,
info
.
ModelName
}
// retryDelay < 阈值:智能重试
...
...
@@ -2377,10 +2239,10 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
,
)
*
handleModelRateLimitResult
{
//
✨
模型级限流处理(
在原有逻辑之前
)
// 模型级限流处理(
优先
)
result
:=
s
.
handleModelRateLimit
(
&
handleModelRateLimitParams
{
ctx
:
ctx
,
prefix
:
prefix
,
...
...
@@ -2402,53 +2264,36 @@ func (s *AntigravityGatewayService) handleUpstreamError(
return
nil
}
// ========== 原有逻辑,保持不变 ==========
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
// 429:尝试解析模型级限流,解析失败时兜底为账号级限流
if
statusCode
==
429
{
// 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。
if
logBody
,
maxBytes
:=
s
.
getLogConfig
();
logBody
{
log
.
Printf
(
"[Antigravity-Debug] 429 response body: %s"
,
truncateString
(
string
(
body
),
maxBytes
))
}
useScopeLimit
:=
quotaScope
!=
""
resetAt
:=
ParseGeminiRateLimitResetTime
(
body
)
if
resetAt
==
nil
{
// 解析失败:使用默认限流时间(与临时限流保持一致)
// 可通过配置或环境变量覆盖
defaultDur
:=
antigravityDefaultRateLimitDuration
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
>
0
{
defaultDur
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
)
*
time
.
Minute
}
// 秒级环境变量优先级最高
if
override
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
override
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
defaultDur
:=
s
.
getDefaultRateLimitDuration
()
// 尝试解析模型 key 并设置模型级限流
modelKey
:=
resolveAntigravityModelKey
(
requestedModel
)
if
modelKey
!=
""
{
ra
:=
s
.
resolveResetTime
(
resetAt
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetModelRateLimit
(
ctx
,
account
.
ID
,
modelKey
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 model_rate_limit_set_failed model=%s error=%v"
,
prefix
,
modelKey
,
err
)
}
else
{
log
.
Printf
(
"%s status=429 rate_limited account=%d reset_in=%v (fallback)"
,
prefix
,
account
.
ID
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed account=%d error=%v"
,
prefix
,
account
.
ID
,
err
)
}
log
.
Printf
(
"%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v"
,
prefix
,
modelKey
,
account
.
ID
,
ra
.
Format
(
"15:04:05"
),
time
.
Until
(
ra
)
.
Truncate
(
time
.
Second
))
s
.
updateAccountModelRateLimitInCache
(
ctx
,
account
,
modelKey
,
ra
)
}
return
nil
}
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v"
,
prefix
,
quotaScope
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
resetTime
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
}
else
{
log
.
Printf
(
"%s status=429 rate_limited account=%d reset_at=%v reset_in=%v"
,
prefix
,
account
.
ID
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
// 无法解析模型 key,兜底为账号级限流
ra
:=
s
.
resolveResetTime
(
resetAt
,
defaultDur
)
log
.
Printf
(
"%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)"
,
prefix
,
account
.
ID
,
ra
.
Format
(
"15:04:05"
),
time
.
Until
(
ra
)
.
Truncate
(
time
.
Second
))
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed account=%d error=%v"
,
prefix
,
account
.
ID
,
err
)
}
}
return
nil
}
// 其他错误码继续使用 rateLimitService
...
...
@@ -2462,11 +2307,92 @@ func (s *AntigravityGatewayService) handleUpstreamError(
return
nil
}
// antigravityStreamResult carries what a streaming handler collected while
// relaying an upstream response.
type antigravityStreamResult struct {
	usage        *ClaudeUsage // usage figures gathered from the stream
	firstTokenMs *int         // milliseconds until the first token; nil if none was seen
}
// getDefaultRateLimitDuration 获取默认限流时间
func
(
s
*
AntigravityGatewayService
)
getDefaultRateLimitDuration
()
time
.
Duration
{
defaultDur
:=
antigravityDefaultRateLimitDuration
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
>
0
{
defaultDur
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
)
*
time
.
Minute
}
if
override
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
override
}
return
defaultDur
}
// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点
func
(
s
*
AntigravityGatewayService
)
resolveResetTime
(
resetAt
*
int64
,
defaultDur
time
.
Duration
)
time
.
Time
{
if
resetAt
!=
nil
{
return
time
.
Unix
(
*
resetAt
,
0
)
}
return
time
.
Now
()
.
Add
(
defaultDur
)
}
// antigravityStreamResult carries what a streaming handler collected while
// relaying an upstream response.
type antigravityStreamResult struct {
	usage            *ClaudeUsage // usage figures gathered from the stream
	firstTokenMs     *int         // milliseconds until the first token; nil if none was seen
	clientDisconnect bool         // whether the client disconnected while the stream was in flight
}
// antigravityClientWriter wraps client-side writes for streaming responses and
// automatically detects and records a client disconnect.
// Once disconnected, every write becomes a no-op; callers check Disconnected()
// to decide whether to keep draining the upstream.
type antigravityClientWriter struct {
	w            gin.ResponseWriter
	flusher      http.Flusher
	disconnected bool
	prefix       string // log prefix identifying the calling method
}
func
newAntigravityClientWriter
(
w
gin
.
ResponseWriter
,
flusher
http
.
Flusher
,
prefix
string
)
*
antigravityClientWriter
{
return
&
antigravityClientWriter
{
w
:
w
,
flusher
:
flusher
,
prefix
:
prefix
}
}
// Write 写入数据到客户端,写入失败时标记断开并返回 false
func
(
cw
*
antigravityClientWriter
)
Write
(
p
[]
byte
)
bool
{
if
cw
.
disconnected
{
return
false
}
if
_
,
err
:=
cw
.
w
.
Write
(
p
);
err
!=
nil
{
cw
.
markDisconnected
()
return
false
}
cw
.
flusher
.
Flush
()
return
true
}
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
func
(
cw
*
antigravityClientWriter
)
Fprintf
(
format
string
,
args
...
any
)
bool
{
if
cw
.
disconnected
{
return
false
}
if
_
,
err
:=
fmt
.
Fprintf
(
cw
.
w
,
format
,
args
...
);
err
!=
nil
{
cw
.
markDisconnected
()
return
false
}
cw
.
flusher
.
Flush
()
return
true
}
// Disconnected reports whether a previous write to the client has failed.
func (cw *antigravityClientWriter) Disconnected() bool {
	return cw.disconnected
}
// markDisconnected flags the client as gone and logs the event, tagged with
// the writer's prefix.
func (cw *antigravityClientWriter) markDisconnected() {
	cw.disconnected = true
	log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
}
// handleStreamReadError holds the shared handling for upstream read errors.
//
// It returns (disconnect, handled). handled=true means the error needs no
// further treatment and the caller should return the usage collected so far;
// this is the case when err wraps context.Canceled or
// context.DeadlineExceeded, or when the client had already disconnected.
// In both of those cases disconnect is true as well. Any other error yields
// (false, false).
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
	contextEnded := errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)

	switch {
	case contextEnded:
		log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
		return true, true
	case clientDisconnected:
		log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
		return true, true
	default:
		return false, false
	}
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
antigravityStreamResult
,
error
)
{
c
.
Status
(
resp
.
StatusCode
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
...
...
@@ -2542,10 +2468,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
intervalCh
=
intervalTicker
.
C
}
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity gemini"
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
cw
.
Disconnected
()
{
return
}
errorEventSent
=
true
...
...
@@ -2557,9 +2485,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
cw
.
Disconnected
()
},
nil
}
if
ev
.
err
!=
nil
{
if
disconnect
,
handled
:=
handleStreamReadError
(
ev
.
err
,
cw
.
Disconnected
(),
"antigravity gemini"
);
handled
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
disconnect
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
...
...
@@ -2574,11 +2505,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
cw
.
Fprintf
(
"%s
\n
"
,
line
)
continue
}
...
...
@@ -2614,27 +2541,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs
=
&
ms
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
cw
.
Fprintf
(
"data: %s
\n\n
"
,
payload
)
continue
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
cw
.
Fprintf
(
"%s
\n
"
,
line
)
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
if
cw
.
Disconnected
()
{
log
.
Printf
(
"Upstream timeout after client disconnect (antigravity gemini), returning collected usage"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
...
...
@@ -3338,10 +3260,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
intervalCh
=
intervalTicker
.
C
}
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity claude"
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
cw
.
Disconnected
()
{
return
}
errorEventSent
=
true
...
...
@@ -3349,19 +3273,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
flusher
.
Flush
()
}
// finishUsage 是获取 processor 最终 usage 的辅助函数
finishUsage
:=
func
()
*
ClaudeUsage
{
_
,
agUsage
:=
processor
.
Finish
()
return
convertUsage
(
agUsage
)
}
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
// 发送结束事件
//
上游完成,
发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
cw
.
Write
(
finalEvents
)
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
cw
.
Disconnected
()
},
nil
}
if
ev
.
err
!=
nil
{
if
disconnect
,
handled
:=
handleStreamReadError
(
ev
.
err
,
cw
.
Disconnected
(),
"antigravity claude"
);
handled
{
return
&
antigravityStreamResult
{
usage
:
finishUsage
(),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
disconnect
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
...
...
@@ -3371,25 +3303,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
ev
.
line
,
"
\r\n
"
))
if
len
(
claudeEvents
)
>
0
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
writeErr
:=
c
.
Writer
.
Write
(
claudeEvents
);
writeErr
!=
nil
{
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
flusher
.
Flush
()
cw
.
Write
(
claudeEvents
)
}
case
<-
intervalCh
:
...
...
@@ -3397,13 +3318,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
if
cw
.
Disconnected
()
{
log
.
Printf
(
"Upstream timeout after client disconnect (antigravity claude), returning collected usage"
)
return
&
antigravityStreamResult
{
usage
:
finishUsage
(),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
...
...
@@ -3542,3 +3465,288 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
payload
[
"contents"
]
=
filtered
return
json
.
Marshal
(
payload
)
}
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
func
(
s
*
AntigravityGatewayService
)
ForwardUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
// 获取上游配置
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 解析请求获取模型信息
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
// 创建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
// 设置请求头
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
// Claude API 兼容
// 透传 Claude 相关 headers
if
v
:=
c
.
GetHeader
(
"anthropic-version"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
v
)
}
if
v
:=
c
.
GetHeader
(
"anthropic-beta"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
v
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
log
.
Printf
(
"%s upstream request failed: %v"
,
prefix
,
err
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 429 错误时标记账号限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
originalModel
,
0
,
""
,
false
)
}
// 透传上游错误
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
resp
.
StatusCode
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billingModel
,
},
nil
}
// 处理成功响应(流式/非流式)
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
claudeReq
.
Stream
{
// 流式响应:透传
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
streamRes
:=
s
.
streamUpstreamResponse
(
c
,
resp
,
startTime
)
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
}
else
{
// 非流式响应:直接透传
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read upstream response: %w"
,
err
)
}
// 提取 usage
usage
=
s
.
extractClaudeUsage
(
respBody
)
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
http
.
StatusOK
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
}
// 构建计费结果
duration
:=
time
.
Since
(
startTime
)
log
.
Printf
(
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billingModel
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
Usage
:
ClaudeUsage
{
InputTokens
:
usage
.
InputTokens
,
OutputTokens
:
usage
.
OutputTokens
,
CacheReadInputTokens
:
usage
.
CacheReadInputTokens
,
CacheCreationInputTokens
:
usage
.
CacheCreationInputTokens
,
},
},
nil
}
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
*
antigravityStreamResult
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity upstream"
)
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
cw
.
Disconnected
()}
}
if
ev
.
err
!=
nil
{
if
disconnect
,
handled
:=
handleStreamReadError
(
ev
.
err
,
cw
.
Disconnected
(),
"antigravity upstream"
);
handled
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
disconnect
}
}
log
.
Printf
(
"Stream read error (antigravity upstream): %v"
,
ev
.
err
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
}
}
line
:=
ev
.
line
// 记录首 token 时间
if
firstTokenMs
==
nil
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
s
.
extractSSEUsage
(
line
,
usage
)
// 透传行
cw
.
Fprintf
(
"%s
\n
"
,
line
)
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
if
cw
.
Disconnected
()
{
log
.
Printf
(
"Upstream timeout after client disconnect (antigravity upstream), returning collected usage"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
}
}
log
.
Printf
(
"Stream data interval timeout (antigravity upstream)"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
}
}
}
}
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
func
(
s
*
AntigravityGatewayService
)
extractSSEUsage
(
line
string
,
usage
*
ClaudeUsage
)
{
if
!
strings
.
HasPrefix
(
line
,
"data: "
)
{
return
}
dataStr
:=
strings
.
TrimPrefix
(
line
,
"data: "
)
var
event
map
[
string
]
any
if
json
.
Unmarshal
([]
byte
(
dataStr
),
&
event
)
!=
nil
{
return
}
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
backend/internal/service/antigravity_gateway_service_test.go
View file @
d367d1cd
...
...
@@ -4,18 +4,42 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter is a gin.ResponseWriter that simulates a client
// disconnect.
type antigravityFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed to succeed; every later write returns an error
	writes    int
}
func
(
w
*
antigravityFailingWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed: client disconnected"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func
newAntigravityTestService
(
cfg
*
config
.
Config
)
*
AntigravityGatewayService
{
return
&
AntigravityGatewayService
{
settingService
:
&
SettingService
{
cfg
:
cfg
},
}
}
func
TestStripSignatureSensitiveBlocksFromClaudeRequest
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
...
...
@@ -338,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
//
验证:
ForwardGemini
粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
verifies
//
that
ForwardGemini
sets ForceCacheBilling=true for sticky session switch.
func
TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
...
...
@@ -393,10 +417,16 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
func
TestAntigravityStreamUpstreamResponse_UsageAndFirstToken
(
t
*
testing
.
T
)
{
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func
TestStreamUpstreamResponse_UsageAndFirstToken
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
...
...
@@ -404,25 +434,458 @@ func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"
data: {
\
"
usage
\
"
:{
\
"
input_tokens
\
"
:1,
\
"
output_tokens
\
"
:2,
\
"
cache_read_input_tokens
\
"
:3,
\
"
cache_creation_input_tokens
\
"
:4}}
\n
"
)
)
_
,
_
=
pw
.
Write
([]
byte
(
"
data: {
\
"
usage
\
"
:{
\
"
output_tokens
\
"
:5}}
\n
"
)
)
fmt
.
Fprintln
(
pw
,
`
data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}
`
)
fmt
.
Fprintln
(
pw
,
`
data: {"usage":{"output_tokens":5}}
`
)
}()
svc
:=
&
AntigravityGatewayService
{}
start
:=
time
.
Now
()
.
Add
(
-
10
*
time
.
Millisecond
)
usage
,
firstTokenMs
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
start
)
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
start
)
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
usage
)
require
.
Equal
(
t
,
1
,
usage
.
InputTokens
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
1
,
result
.
usage
.
InputTokens
)
// 第二次事件覆盖 output_tokens
require
.
Equal
(
t
,
5
,
usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
4
,
usage
.
CacheCreationInputTokens
)
require
.
Equal
(
t
,
5
,
result
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
4
,
result
.
usage
.
CacheCreationInputTokens
)
require
.
NotNil
(
t
,
result
.
firstTokenMs
)
if
firstTokenMs
==
nil
{
t
.
Fatalf
(
"expected firstTokenMs to be set"
)
}
// 确保有透传输出
require
.
True
(
t
,
strings
.
Contains
(
writer
.
Body
.
String
(),
"data:"
))
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"data:"
)
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
func
TestStreamUpstreamResponse_NormalComplete
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
fmt
.
Fprintln
(
pw
,
`event: message_start`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
fmt
.
Fprintln
(
pw
,
`event: content_block_delta`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"content_block_delta","delta":{"text":"hello"}}`
)
fmt
.
Fprintln
(
pw
,
""
)
fmt
.
Fprintln
(
pw
,
`event: message_delta`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_delta","usage":{"output_tokens":5}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
,
"normal completion should not set clientDisconnect"
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
5
,
result
.
usage
.
OutputTokens
,
"should collect output_tokens from message_delta"
)
require
.
NotNil
(
t
,
result
.
firstTokenMs
,
"should record first token time"
)
// 验证数据被透传到客户端
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: message_start"
)
require
.
Contains
(
t
,
body
,
"content_block_delta"
)
require
.
Contains
(
t
,
body
,
"message_delta"
)
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
func
TestHandleGeminiStreamingResponse_NormalComplete
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// 第一个 chunk(部分内容)
fmt
.
Fprintln
(
pw
,
`data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`
)
fmt
.
Fprintln
(
pw
,
""
)
// 第二个 chunk(最终内容+完整 usage)
fmt
.
Fprintln
(
pw
,
`data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleGeminiStreamingResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
,
"normal completion should not set clientDisconnect"
)
require
.
NotNil
(
t
,
result
.
usage
)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require
.
Equal
(
t
,
8
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
8
,
result
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
2
,
result
.
usage
.
CacheReadInputTokens
)
require
.
NotNil
(
t
,
result
.
firstTokenMs
,
"should record first token time"
)
// 验证数据被透传到客户端
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"Hello"
)
require
.
Contains
(
t
,
body
,
"world"
)
// 不应包含错误事件
require
.
NotContains
(
t
,
body
,
"event: error"
)
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
func
TestHandleClaudeStreamingResponse_NormalComplete
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
fmt
.
Fprintln
(
pw
,
`data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
,
"normal completion should not set clientDisconnect"
)
require
.
NotNil
(
t
,
result
.
usage
)
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require
.
Equal
(
t
,
5
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
OutputTokens
)
require
.
NotNil
(
t
,
result
.
firstTokenMs
,
"should record first token time"
)
// 验证输出是 Claude SSE 格式(processor 会转换)
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: message_start"
,
"should contain Claude message_start event"
)
require
.
Contains
(
t
,
body
,
"event: message_stop"
,
"should contain Claude message_stop event"
)
// 不应包含错误事件
require
.
NotContains
(
t
,
body
,
"event: error"
)
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
func
TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
fmt
.
Fprintln
(
pw
,
`event: message_start`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
fmt
.
Fprintln
(
pw
,
`event: message_delta`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_delta","usage":{"output_tokens":20}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
20
,
result
.
usage
.
OutputTokens
)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证:context 取消时返回 usage 且标记 clientDisconnect
func
TestStreamUpstreamResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func
TestStreamUpstreamResponse_Timeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pw
.
Close
()
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func
TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
// 不关闭 pw → 等待超时
}()
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pw
.
Close
()
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
func
TestHandleGeminiStreamingResponse_ClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
fmt
.
Fprintln
(
pw
,
`data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleGeminiStreamingResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"write_failed"
)
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func
TestHandleGeminiStreamingResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
,
err
:=
svc
.
handleGeminiStreamingResponse
(
c
,
resp
,
time
.
Now
())
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
func
TestHandleClaudeStreamingResponse_ClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// v1internal 包装格式
fmt
.
Fprintln
(
pw
,
`data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func
TestHandleClaudeStreamingResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func
TestExtractSSEUsage
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
line
string
expected
ClaudeUsage
}{
{
name
:
"message_delta with output_tokens"
,
line
:
`data: {"type":"message_delta","usage":{"output_tokens":42}}`
,
expected
:
ClaudeUsage
{
OutputTokens
:
42
},
},
{
name
:
"non-data line ignored"
,
line
:
`event: message_start`
,
expected
:
ClaudeUsage
{},
},
{
name
:
"top-level usage with all fields"
,
line
:
`data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`
,
expected
:
ClaudeUsage
{
InputTokens
:
10
,
OutputTokens
:
20
,
CacheReadInputTokens
:
5
,
CacheCreationInputTokens
:
3
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
usage
:=
&
ClaudeUsage
{}
svc
.
extractSSEUsage
(
tt
.
line
,
usage
)
require
.
Equal
(
t
,
tt
.
expected
,
*
usage
)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func
TestAntigravityClientWriter
(
t
*
testing
.
T
)
{
t
.
Run
(
"normal write succeeds"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"test"
)
ok
:=
cw
.
Write
([]
byte
(
"hello"
))
require
.
True
(
t
,
ok
)
require
.
False
(
t
,
cw
.
Disconnected
())
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"hello"
)
})
t
.
Run
(
"write failure marks disconnected"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
fw
:=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
fw
,
flusher
,
"test"
)
ok
:=
cw
.
Write
([]
byte
(
"hello"
))
require
.
False
(
t
,
ok
)
require
.
True
(
t
,
cw
.
Disconnected
())
})
t
.
Run
(
"subsequent writes are no-op"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
fw
:=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
fw
,
flusher
,
"test"
)
cw
.
Write
([]
byte
(
"first"
))
ok
:=
cw
.
Fprintf
(
"second %d"
,
2
)
require
.
False
(
t
,
ok
)
require
.
True
(
t
,
cw
.
Disconnected
())
})
}
backend/internal/service/antigravity_quota_scope.go
View file @
d367d1cd
...
...
@@ -2,63 +2,23 @@ package service
import
(
"context"
"slices"
"strings"
"time"
)
// antigravityQuotaScopesKey is the Account.Extra key under which per-scope
// quota data is stored.
const antigravityQuotaScopesKey = "antigravity_quota_scopes"

// AntigravityQuotaScope identifies an Antigravity quota scope (quota domain).
type AntigravityQuotaScope string

// Quota scopes referenced in this file.
const (
	AntigravityQuotaScopeClaude      AntigravityQuotaScope = "claude"
	AntigravityQuotaScopeGeminiText  AntigravityQuotaScope = "gemini_text"
	AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func
IsScopeSupported
(
supportedScopes
[]
string
,
scope
AntigravityQuotaScope
)
bool
{
if
len
(
supportedScopes
)
==
0
{
// 未配置时默认全部支持
return
true
}
supported
:=
slices
.
Contains
(
supportedScopes
,
string
(
scope
))
return
supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func
ResolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
return
resolveAntigravityQuotaScope
(
requestedModel
)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
if
model
==
""
{
return
""
,
false
}
switch
{
case
strings
.
HasPrefix
(
model
,
"claude-"
)
:
return
AntigravityQuotaScopeClaude
,
true
case
strings
.
HasPrefix
(
model
,
"gemini-"
)
:
if
isImageGenerationModel
(
model
)
{
return
AntigravityQuotaScopeGeminiImage
,
true
}
return
AntigravityQuotaScopeGeminiText
,
true
default
:
return
""
,
false
}
}
// normalizeAntigravityModelName canonicalizes a model name: surrounding
// whitespace is trimmed, the name is lower-cased, and a leading "models/"
// prefix is removed.
func normalizeAntigravityModelName(model string) string {
	return strings.TrimPrefix(strings.ToLower(strings.TrimSpace(model)), "models/")
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func
resolveAntigravityModelKey
(
requestedModel
string
)
string
{
return
normalizeAntigravityModelName
(
requestedModel
)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func
(
a
*
Account
)
IsSchedulableForModel
(
requestedModel
string
)
bool
{
return
a
.
IsSchedulableForModelWithContext
(
context
.
Background
(),
requestedModel
)
...
...
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
if
a
.
isModelRateLimitedWithContext
(
ctx
,
requestedModel
)
{
return
false
}
if
a
.
Platform
!=
PlatformAntigravity
{
return
true
}
scope
,
ok
:=
resolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
true
}
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
==
nil
{
return
true
}
now
:=
time
.
Now
()
return
!
now
.
Before
(
*
resetAt
)
}
func
(
a
*
Account
)
antigravityQuotaScopeResetAt
(
scope
AntigravityQuotaScope
)
*
time
.
Time
{
if
a
==
nil
||
a
.
Extra
==
nil
||
scope
==
""
{
return
nil
}
rawScopes
,
ok
:=
a
.
Extra
[
antigravityQuotaScopesKey
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
rawScope
,
ok
:=
rawScopes
[
string
(
scope
)]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
resetAtRaw
,
ok
:=
rawScope
[
"rate_limit_reset_at"
]
.
(
string
)
if
!
ok
||
strings
.
TrimSpace
(
resetAtRaw
)
==
""
{
return
nil
}
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtRaw
)
if
err
!=
nil
{
return
nil
}
return
&
resetAt
}
// antigravityAllScopes lists the quota scopes that
// GetAntigravityScopeRateLimits iterates over.
var antigravityAllScopes = []AntigravityQuotaScope{
	AntigravityQuotaScopeClaude,
	AntigravityQuotaScopeGeminiText,
	AntigravityQuotaScopeGeminiImage,
}
func
(
a
*
Account
)
GetAntigravityScopeRateLimits
()
map
[
string
]
int64
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAntigravity
{
return
nil
}
now
:=
time
.
Now
()
result
:=
make
(
map
[
string
]
int64
)
for
_
,
scope
:=
range
antigravityAllScopes
{
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
!=
nil
&&
now
.
Before
(
*
resetAt
)
{
remainingSec
:=
int64
(
time
.
Until
(
*
resetAt
)
.
Seconds
())
if
remainingSec
>
0
{
result
[
string
(
scope
)]
=
remainingSec
}
}
}
if
len
(
result
)
==
0
{
return
nil
}
return
result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func
(
a
*
Account
)
GetQuotaScopeRateLimitRemainingTime
(
requestedModel
string
)
time
.
Duration
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAntigravity
{
return
0
}
scope
,
ok
:=
resolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
0
}
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
==
nil
{
return
0
}
if
remaining
:=
time
.
Until
(
*
resetAt
);
remaining
>
0
{
return
remaining
}
return
0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func
(
a
*
Account
)
GetRateLimitRemainingTime
(
requestedModel
string
)
time
.
Duration
{
return
a
.
GetRateLimitRemainingTimeWithContext
(
context
.
Background
(),
requestedModel
)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流
和模型域限流取最大值
)
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型
级
限流)
// 返回 0 表示未限流或已过期
func
(
a
*
Account
)
GetRateLimitRemainingTimeWithContext
(
ctx
context
.
Context
,
requestedModel
string
)
time
.
Duration
{
if
a
==
nil
{
return
0
}
modelRemaining
:=
a
.
GetModelRateLimitRemainingTimeWithContext
(
ctx
,
requestedModel
)
scopeRemaining
:=
a
.
GetQuotaScopeRateLimitRemainingTime
(
requestedModel
)
if
modelRemaining
>
scopeRemaining
{
return
modelRemaining
}
return
scopeRemaining
return
a
.
GetModelRateLimitRemainingTimeWithContext
(
ctx
,
requestedModel
)
}
backend/internal/service/antigravity_rate_limit_test.go
View file @
d367d1cd
...
...
@@ -65,12 +65,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
return
s
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
// scopeLimitCall records the arguments of one SetAntigravityQuotaScopeLimit
// call made against the stub repository.
type scopeLimitCall struct {
	accountID int64
	scope     AntigravityQuotaScope
	resetAt   time.Time
}
type
rateLimitCall
struct
{
accountID
int64
resetAt
time
.
Time
...
...
@@ -84,16 +78,10 @@ type modelRateLimitCall struct {
// stubAntigravityAccountRepo is a test double that embeds AccountRepository
// and records the rate-limit calls it receives so tests can assert on them.
type stubAntigravityAccountRepo struct {
	AccountRepository
	scopeCalls          []scopeLimitCall     // recorded SetAntigravityQuotaScopeLimit calls
	rateCalls           []rateLimitCall      // recorded SetRateLimited calls
	modelRateLimitCalls []modelRateLimitCall // recorded model-level rate-limit calls
}
func
(
s
*
stubAntigravityAccountRepo
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
s
.
scopeCalls
=
append
(
s
.
scopeCalls
,
scopeLimitCall
{
accountID
:
id
,
scope
:
scope
,
resetAt
:
resetAt
})
return
nil
}
func
(
s
*
stubAntigravityAccountRepo
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
s
.
rateCalls
=
append
(
s
.
rateCalls
,
rateLimitCall
{
accountID
:
id
,
resetAt
:
resetAt
})
return
nil
...
...
@@ -137,10 +125,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
quotaScope
:
AntigravityQuotaScopeClaude
,
httpUpstream
:
upstream
,
requestedModel
:
"claude-sonnet-4-5"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleErrorCalled
=
true
return
nil
},
...
...
@@ -161,23 +148,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require
.
Equal
(
t
,
base2
,
available
[
0
])
}
func
TestAntigravityHandleUpstreamError_UsesScopeLimit
(
t
*
testing
.
T
)
{
// 分区限流始终开启,不再支持通过环境变量关闭
repo
:=
&
stubAntigravityAccountRepo
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
9
,
Name
:
"acc-9"
,
Platform
:
PlatformAntigravity
}
body
:=
buildGeminiRateLimitBody
(
"3s"
)
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusTooManyRequests
,
http
.
Header
{},
body
,
AntigravityQuotaScopeClaude
,
0
,
""
,
false
)
require
.
Len
(
t
,
repo
.
scopeCalls
,
1
)
require
.
Empty
(
t
,
repo
.
rateCalls
)
call
:=
repo
.
scopeCalls
[
0
]
require
.
Equal
(
t
,
account
.
ID
,
call
.
accountID
)
require
.
Equal
(
t
,
AntigravityQuotaScopeClaude
,
call
.
scope
)
require
.
WithinDuration
(
t
,
time
.
Now
()
.
Add
(
3
*
time
.
Second
),
call
.
resetAt
,
2
*
time
.
Second
)
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
func
TestHandleUpstreamError_429_ModelRateLimit
(
t
*
testing
.
T
)
{
repo
:=
&
stubAntigravityAccountRepo
{}
...
...
@@ -195,7 +165,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
}
}`
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusTooManyRequests
,
http
.
Header
{},
body
,
AntigravityQuotaScopeClaude
,
0
,
""
,
false
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusTooManyRequests
,
http
.
Header
{},
body
,
"claude-sonnet-4-5"
,
0
,
""
,
false
)
// 应该触发模型限流
require
.
NotNil
(
t
,
result
)
...
...
@@ -206,22 +176,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走
scope 限流
)
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走
模型级限流兜底
)
func
TestHandleUpstreamError_429_NonModelRateLimit
(
t
*
testing
.
T
)
{
repo
:=
&
stubAntigravityAccountRepo
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
2
,
Name
:
"acc-2"
,
Platform
:
PlatformAntigravity
}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→
scope 限流
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→
走模型级限流兜底
body
:=
buildGeminiRateLimitBody
(
"5s"
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusTooManyRequests
,
http
.
Header
{},
body
,
AntigravityQuotaScopeClaude
,
0
,
""
,
false
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusTooManyRequests
,
http
.
Header
{},
body
,
"claude-sonnet-4-5"
,
0
,
""
,
false
)
// 不应该触发模型限流,应该走 scope 限流
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED),
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
require
.
Nil
(
t
,
result
)
require
.
Empty
(
t
,
repo
.
modelRateLimitCalls
)
require
.
Len
(
t
,
repo
.
scopeCalls
,
1
)
require
.
Equal
(
t
,
AntigravityQuotaScopeClaude
,
repo
.
scopeCalls
[
0
]
.
scope
)
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
...
...
@@ -241,7 +211,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
}
}`
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
AntigravityQuotaScopeGeminiText
,
0
,
""
,
false
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
"gemini-3-pro-high"
,
0
,
""
,
false
)
// 应该触发模型限流
require
.
NotNil
(
t
,
result
)
...
...
@@ -269,12 +239,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
}
}`
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
AntigravityQuotaScopeGeminiText
,
0
,
""
,
false
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
"gemini-3-pro-high"
,
0
,
""
,
false
)
// 503 非模型限流不应该做任何处理
require
.
Nil
(
t
,
result
)
require
.
Empty
(
t
,
repo
.
modelRateLimitCalls
,
"503 non-model rate limit should not trigger model rate limit"
)
require
.
Empty
(
t
,
repo
.
scopeCalls
,
"503 non-model rate limit should not trigger scope rate limit"
)
require
.
Empty
(
t
,
repo
.
rateCalls
,
"503 non-model rate limit should not trigger account rate limit"
)
}
...
...
@@ -287,12 +256,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
// 503 + 空响应体 → 不做任何处理
body
:=
[]
byte
(
`{}`
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
AntigravityQuotaScopeGeminiText
,
0
,
""
,
false
)
result
:=
svc
.
handleUpstreamError
(
context
.
Background
(),
"[test]"
,
account
,
http
.
StatusServiceUnavailable
,
http
.
Header
{},
body
,
"gemini-3-pro-high"
,
0
,
""
,
false
)
// 503 空响应不应该做任何处理
require
.
Nil
(
t
,
result
)
require
.
Empty
(
t
,
repo
.
modelRateLimitCalls
)
require
.
Empty
(
t
,
repo
.
scopeCalls
)
require
.
Empty
(
t
,
repo
.
rateCalls
)
}
...
...
@@ -313,15 +281,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
require
.
False
(
t
,
account
.
IsSchedulableForModel
(
"gemini-3-flash"
))
account
.
RateLimitResetAt
=
nil
account
.
Extra
=
map
[
string
]
any
{
antigravityQuotaScopesKey
:
map
[
string
]
any
{
"claude"
:
map
[
string
]
any
{
"rate_limit_reset_at"
:
future
.
Format
(
time
.
RFC3339
),
},
},
}
require
.
False
(
t
,
account
.
IsSchedulableForModel
(
"claude-sonnet-4-5"
))
require
.
True
(
t
,
account
.
IsSchedulableForModel
(
"claude-sonnet-4-5"
))
require
.
True
(
t
,
account
.
IsSchedulableForModel
(
"gemini-3-flash"
))
}
...
...
@@ -641,6 +601,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`
,
expectedShouldRetry
:
false
,
expectedShouldRateLimit
:
true
,
minWait
:
7
*
time
.
Second
,
modelName
:
"gemini-pro"
,
},
{
...
...
@@ -658,6 +619,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`
,
expectedShouldRetry
:
false
,
expectedShouldRateLimit
:
true
,
minWait
:
39
*
time
.
Second
,
modelName
:
"gemini-3-pro-high"
,
},
{
...
...
@@ -675,6 +637,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`
,
expectedShouldRetry
:
false
,
expectedShouldRateLimit
:
true
,
minWait
:
30
*
time
.
Second
,
modelName
:
"gemini-2.5-flash"
,
},
{
...
...
@@ -692,6 +655,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`
,
expectedShouldRetry
:
false
,
expectedShouldRateLimit
:
true
,
minWait
:
30
*
time
.
Second
,
modelName
:
"claude-sonnet-4-5"
,
},
}
...
...
@@ -710,6 +674,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
t
.
Errorf
(
"wait = %v, want >= %v"
,
wait
,
tt
.
minWait
)
}
}
if
shouldRateLimit
&&
tt
.
minWait
>
0
{
if
wait
<
tt
.
minWait
{
t
.
Errorf
(
"rate limit wait = %v, want >= %v"
,
wait
,
tt
.
minWait
)
}
}
if
(
shouldRetry
||
shouldRateLimit
)
&&
model
!=
tt
.
modelName
{
t
.
Errorf
(
"modelName = %q, want %q"
,
model
,
tt
.
modelName
)
}
...
...
@@ -809,7 +778,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
require
.
NotEqual
(
t
,
"claude_sonnet"
,
call
.
modelKey
,
"should NOT be scope"
)
}
func
TestAntigravityRetryLoop_PreCheck_
WaitsWhenRemainingBelowThreshol
d
(
t
*
testing
.
T
)
{
func
TestAntigravityRetryLoop_PreCheck_
SwitchesWhenRateLimite
d
(
t
*
testing
.
T
)
{
upstream
:=
&
recordingOKUpstream
{}
account
:=
&
Account
{
ID
:
1
,
...
...
@@ -821,19 +790,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
Extra
:
map
[
string
]
any
{
modelRateLimitsKey
:
map
[
string
]
any
{
"claude-sonnet-4-5"
:
map
[
string
]
any
{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at"
:
time
.
Now
()
.
Add
(
2
*
time
.
Second
)
.
Format
(
time
.
RFC3339
),
},
},
},
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Millisecond
)
defer
cancel
()
svc
:=
&
AntigravityGatewayService
{}
result
,
err
:=
svc
.
antigravityRetryLoop
(
antigravityRetryLoopParams
{
ctx
:
c
tx
,
ctx
:
c
ontext
.
Background
()
,
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
...
...
@@ -842,17 +807,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
requestedModel
:
"claude-sonnet-4-5"
,
httpUpstream
:
upstream
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
})
require
.
ErrorIs
(
t
,
err
,
context
.
DeadlineExceeded
)
require
.
Nil
(
t
,
result
)
require
.
Equal
(
t
,
0
,
upstream
.
calls
,
"should not call upstream while waiting on pre-check"
)
var
switchErr
*
AntigravityAccountSwitchError
require
.
ErrorAs
(
t
,
err
,
&
switchErr
)
require
.
Equal
(
t
,
account
.
ID
,
switchErr
.
OriginalAccountID
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
switchErr
.
RateLimitedModel
)
require
.
True
(
t
,
switchErr
.
IsStickySession
)
require
.
Equal
(
t
,
0
,
upstream
.
calls
,
"should not call upstream when switching on pre-check"
)
}
func
TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemaining
AtOrAboveThreshold
(
t
*
testing
.
T
)
{
func
TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemaining
Long
(
t
*
testing
.
T
)
{
upstream
:=
&
recordingOKUpstream
{}
account
:=
&
Account
{
ID
:
2
,
...
...
@@ -881,7 +850,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t
requestedModel
:
"claude-sonnet-4-5"
,
httpUpstream
:
upstream
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
})
...
...
backend/internal/service/antigravity_smart_retry_test.go
View file @
d367d1cd
...
...
@@ -13,6 +13,23 @@ import (
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache is a GatewayCache mock used by the handleSmartRetry tests.
// It only records calls made to DeleteSessionAccountID.
type stubSmartRetryCache struct {
	GatewayCache // embedded interface: any method not implemented here panics, so only the expected method can be called
	deleteCalls []deleteSessionCall
}
// deleteSessionCall captures the arguments of one DeleteSessionAccountID call
// recorded by stubSmartRetryCache.
type deleteSessionCall struct {
	groupID     int64
	sessionHash string
}
func
(
c
*
stubSmartRetryCache
)
DeleteSessionAccountID
(
_
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
c
.
deleteCalls
=
append
(
c
.
deleteCalls
,
deleteSessionCall
{
groupID
:
groupID
,
sessionHash
:
sessionHash
})
return
nil
}
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type
mockSmartRetryUpstream
struct
{
responses
[]
*
http
.
Response
...
...
@@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
body
:
[]
byte
(
`{"input":"test"}`
),
accountRepo
:
repo
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
func
TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError
(
t
*
testing
.
T
)
{
// 智能重试后仍然返回 429(需要提供
3
个响应,因为智能重试最多
3
次)
// 智能重试后仍然返回 429(需要提供
1
个响应,因为智能重试最多
1
次)
failRespBody
:=
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
...
...
@@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
failResp2
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
failResp3
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
failResp1
,
failResp2
,
failResp3
},
errors
:
[]
error
{
nil
,
nil
,
nil
},
responses
:
[]
*
http
.
Response
{
failResp1
},
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
...
...
@@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Platform
:
PlatformAntigravity
,
}
// 3s < 7s 阈值,应该触发智能重试(最多
3
次)
// 3s < 7s 阈值,应该触发智能重试(最多
1
次)
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
...
...
@@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
false
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -284,7 +291,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
// 验证模型限流已设置
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"gemini-3-flash"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
require
.
Len
(
t
,
upstream
.
calls
,
3
,
"should have made
thre
e retry call
s
(max attempts)"
)
require
.
Len
(
t
,
upstream
.
calls
,
1
,
"should have made
on
e retry call (max attempts)"
)
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
...
...
@@ -324,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
body
:
[]
byte
(
`{"input":"test"}`
),
accountRepo
:
repo
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -380,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -429,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -480,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
accountRepo
:
repo
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -541,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
})
...
...
@@ -556,19 +563,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
require
.
True
(
t
,
switchErr
.
IsStickySession
)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
func
TestHandleSmartRetry_NetworkError_ContinuesRetry
(
t
*
testing
.
T
)
{
// 第一次网络错误,第二次成功
successResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"result":"ok"}`
)),
}
// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号
func
TestHandleSmartRetry_NetworkError_ExhaustsRetry
(
t
*
testing
.
T
)
{
// 唯一一次重试遇到网络错误(nil response)
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
nil
,
successResp
},
//
第一次
返回 nil(模拟网络错误)
errors
:
[]
error
{
nil
,
nil
},
// mock 不返回 error,靠 nil response 触发
responses
:
[]
*
http
.
Response
{
nil
},
// 返回 nil(模拟网络错误)
errors
:
[]
error
{
nil
}
,
// mock 不返回 error,靠 nil response 触发
}
repo
:=
&
stubAntigravityAccountRepo
{}
account
:=
&
Account
{
ID
:
8
,
Name
:
"acc-8"
,
...
...
@@ -600,7 +603,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
accountRepo
:
repo
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -612,10 +616,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
resp
,
"should return successful response after network error recovery"
)
require
.
Equal
(
t
,
http
.
StatusOK
,
result
.
resp
.
StatusCode
)
require
.
Nil
(
t
,
result
.
switchError
,
"should not return switchError on success"
)
require
.
Len
(
t
,
upstream
.
calls
,
2
,
"should have made two retry calls"
)
require
.
Nil
(
t
,
result
.
resp
,
"should not return resp when switchError is set"
)
require
.
NotNil
(
t
,
result
.
switchError
,
"should return switchError after network error exhausted retry"
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
switchError
.
OriginalAccountID
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
result
.
switchError
.
RateLimitedModel
)
require
.
Len
(
t
,
upstream
.
calls
,
1
,
"should have made one retry call"
)
// 验证模型限流已设置
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
...
...
@@ -653,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
body
:
[]
byte
(
`{"input":"test"}`
),
accountRepo
:
repo
,
isStickySession
:
true
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
...
...
@@ -674,3 +683,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// ---------------------------------------------------------------------------
// 以下测试覆盖本次改动:
// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次)
// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID)
// ---------------------------------------------------------------------------
// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1
func
TestSmartRetryMaxAttempts_VerifyConstant
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
antigravitySmartRetryMaxAttempts
,
"antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting"
)
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定
func
TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
(
t
*
testing
.
T
)
{
failRespBody
:=
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
failResp
},
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
10
,
Name
:
"acc-10"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
42
,
sessionHash
:
"sticky-hash-abc"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
// 验证返回 switchError
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
True
(
t
,
result
.
switchError
.
IsStickySession
,
"switchError should carry IsStickySession=true"
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
switchError
.
OriginalAccountID
)
// 核心断言:DeleteSessionAccountID 被调用,且参数正确
require
.
Len
(
t
,
cache
.
deleteCalls
,
1
,
"should call DeleteSessionAccountID exactly once"
)
require
.
Equal
(
t
,
int64
(
42
),
cache
.
deleteCalls
[
0
]
.
groupID
)
require
.
Equal
(
t
,
"sticky-hash-abc"
,
cache
.
deleteCalls
[
0
]
.
sessionHash
)
// 验证仅重试 1 次
require
.
Len
(
t
,
upstream
.
calls
,
1
,
"should make exactly 1 retry call (maxAttempts=1)"
)
// 验证模型限流已设置
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空)
func
TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
(
t
*
testing
.
T
)
{
failRespBody
:=
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
failResp
},
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
11
,
Name
:
"acc-11"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
false
,
groupID
:
42
,
sessionHash
:
""
,
// 非粘性会话,sessionHash 为空
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
False
(
t
,
result
.
switchError
.
IsStickySession
)
// 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID
require
.
Len
(
t
,
cache
.
deleteCalls
,
0
,
"should NOT call DeleteSessionAccountID when sessionHash is empty"
)
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
// 边界:cache 为 nil 时不应 panic
func
TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
(
t
*
testing
.
T
)
{
failRespBody
:=
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
failResp
},
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
account
:=
&
Account
{
ID
:
12
,
Name
:
"acc-12"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
42
,
sessionHash
:
"sticky-hash-nil-cache"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
// cache 为 nil,不应 panic
svc
:=
&
AntigravityGatewayService
{
cache
:
nil
}
require
.
NotPanics
(
t
,
func
()
{
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
True
(
t
,
result
.
switchError
.
IsStickySession
)
})
}
// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
// 重试成功时不应清除粘性会话(只有失败才清除)
func
TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
(
t
*
testing
.
T
)
{
successResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"result":"ok"}`
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
successResp
},
errors
:
[]
error
{
nil
},
}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
13
,
Name
:
"acc-13"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
isStickySession
:
true
,
groupID
:
42
,
sessionHash
:
"sticky-hash-success"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
resp
,
"should return successful response"
)
require
.
Equal
(
t
,
http
.
StatusOK
,
result
.
resp
.
StatusCode
)
require
.
Nil
(
t
,
result
.
switchError
,
"should not return switchError on success"
)
// 核心断言:重试成功时不应清除粘性会话
require
.
Len
(
t
,
cache
.
deleteCalls
,
0
,
"should NOT call DeleteSessionAccountID on successful retry"
)
}
// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID
// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理)
func
TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
(
t
*
testing
.
T
)
{
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
14
,
Name
:
"acc-14"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
// 15s >= 7s 阈值 → 走长延迟路径
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
42
,
sessionHash
:
"sticky-hash-long-delay"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
smartRetryActionBreakWithResp
,
result
.
action
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
True
(
t
,
result
.
switchError
.
IsStickySession
)
// 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID
// (由上游 handler 的 shouldClearStickySession 处理)
require
.
Len
(
t
,
cache
.
deleteCalls
,
0
,
"long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)"
)
}
// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定
func
TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
(
t
*
testing
.
T
)
{
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
nil
},
// 网络错误
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
15
,
Name
:
"acc-15"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
99
,
sessionHash
:
"sticky-net-error"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
True
(
t
,
result
.
switchError
.
IsStickySession
)
// 核心断言:网络错误耗尽重试后也应清除粘性绑定
require
.
Len
(
t
,
cache
.
deleteCalls
,
1
,
"should call DeleteSessionAccountID after network error exhausts retry"
)
require
.
Equal
(
t
,
int64
(
99
),
cache
.
deleteCalls
[
0
]
.
groupID
)
require
.
Equal
(
t
,
"sticky-net-error"
,
cache
.
deleteCalls
[
0
]
.
sessionHash
)
}
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func
TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
(
t
*
testing
.
T
)
{
failRespBody
:=
`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
failResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusServiceUnavailable
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
failRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
failResp
},
errors
:
[]
error
{
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
16
,
Name
:
"acc-16"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
}
respBody
:=
[]
byte
(
`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusServiceUnavailable
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
params
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
77
,
sessionHash
:
"sticky-503-short"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
}
availableURLs
:=
[]
string
{
"https://ag-1.test"
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
:=
svc
.
handleSmartRetry
(
params
,
resp
,
respBody
,
"https://ag-1.test"
,
0
,
availableURLs
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
switchError
)
require
.
True
(
t
,
result
.
switchError
.
IsStickySession
)
// 验证粘性绑定被清除
require
.
Len
(
t
,
cache
.
deleteCalls
,
1
)
require
.
Equal
(
t
,
int64
(
77
),
cache
.
deleteCalls
[
0
]
.
groupID
)
require
.
Equal
(
t
,
"sticky-503-short"
,
cache
.
deleteCalls
[
0
]
.
sessionHash
)
// 验证模型限流已设置
require
.
Len
(
t
,
repo
.
modelRateLimitCalls
,
1
)
require
.
Equal
(
t
,
"gemini-3-pro"
,
repo
.
modelRateLimitCalls
[
0
]
.
modelKey
)
}
// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播
// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除
func
TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
(
t
*
testing
.
T
)
{
// 初始 429 响应
initialRespBody
:=
[]
byte
(
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
)
initialResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
initialRespBody
)),
}
// 智能重试也返回 429
retryRespBody
:=
`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
retryResp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
http
.
Header
{},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
retryRespBody
)),
}
upstream
:=
&
mockSmartRetryUpstream
{
responses
:
[]
*
http
.
Response
{
initialResp
,
retryResp
},
errors
:
[]
error
{
nil
,
nil
},
}
repo
:=
&
stubAntigravityAccountRepo
{}
cache
:=
&
stubSmartRetryCache
{}
account
:=
&
Account
{
ID
:
17
,
Name
:
"acc-17"
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Schedulable
:
true
,
Status
:
StatusActive
,
Concurrency
:
1
,
}
svc
:=
&
AntigravityGatewayService
{
cache
:
cache
}
result
,
err
:=
svc
.
antigravityRetryLoop
(
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
accountRepo
:
repo
,
isStickySession
:
true
,
groupID
:
55
,
sessionHash
:
"sticky-loop-test"
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
return
nil
},
})
require
.
Nil
(
t
,
result
,
"should not return result when switchError"
)
require
.
NotNil
(
t
,
err
,
"should return error"
)
var
switchErr
*
AntigravityAccountSwitchError
require
.
ErrorAs
(
t
,
err
,
&
switchErr
,
"error should be AntigravityAccountSwitchError"
)
require
.
Equal
(
t
,
account
.
ID
,
switchErr
.
OriginalAccountID
)
require
.
Equal
(
t
,
"claude-opus-4-6"
,
switchErr
.
RateLimitedModel
)
require
.
True
(
t
,
switchErr
.
IsStickySession
,
"IsStickySession must propagate through retryLoop"
)
// 验证粘性绑定被清除
require
.
Len
(
t
,
cache
.
deleteCalls
,
1
,
"should clear sticky session in handleSmartRetry"
)
require
.
Equal
(
t
,
int64
(
55
),
cache
.
deleteCalls
[
0
]
.
groupID
)
require
.
Equal
(
t
,
"sticky-loop-test"
,
cache
.
deleteCalls
[
0
]
.
sessionHash
)
}
backend/internal/service/crs_sync_helpers_test.go
0 → 100644
View file @
d367d1cd
package
service
import
(
"testing"
)
func
TestBuildSelectedSet
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
ids
[]
string
wantNil
bool
wantSize
int
}{
{
name
:
"nil input returns nil (backward compatible: create all)"
,
ids
:
nil
,
wantNil
:
true
,
},
{
name
:
"empty slice returns empty map (create none)"
,
ids
:
[]
string
{},
wantNil
:
false
,
wantSize
:
0
,
},
{
name
:
"single ID"
,
ids
:
[]
string
{
"abc-123"
},
wantNil
:
false
,
wantSize
:
1
,
},
{
name
:
"multiple IDs"
,
ids
:
[]
string
{
"a"
,
"b"
,
"c"
},
wantNil
:
false
,
wantSize
:
3
,
},
{
name
:
"duplicate IDs are deduplicated"
,
ids
:
[]
string
{
"a"
,
"a"
,
"b"
},
wantNil
:
false
,
wantSize
:
2
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
buildSelectedSet
(
tt
.
ids
)
if
tt
.
wantNil
{
if
got
!=
nil
{
t
.
Errorf
(
"buildSelectedSet(%v) = %v, want nil"
,
tt
.
ids
,
got
)
}
return
}
if
got
==
nil
{
t
.
Fatalf
(
"buildSelectedSet(%v) = nil, want non-nil map"
,
tt
.
ids
)
}
if
len
(
got
)
!=
tt
.
wantSize
{
t
.
Errorf
(
"buildSelectedSet(%v) has %d entries, want %d"
,
tt
.
ids
,
len
(
got
),
tt
.
wantSize
)
}
// Verify all unique IDs are present
for
_
,
id
:=
range
tt
.
ids
{
if
_
,
ok
:=
got
[
id
];
!
ok
{
t
.
Errorf
(
"buildSelectedSet(%v) missing key %q"
,
tt
.
ids
,
id
)
}
}
})
}
}
func
TestShouldCreateAccount
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
crsID
string
selectedSet
map
[
string
]
struct
{}
want
bool
}{
{
name
:
"nil set allows all (backward compatible)"
,
crsID
:
"any-id"
,
selectedSet
:
nil
,
want
:
true
,
},
{
name
:
"empty set blocks all"
,
crsID
:
"any-id"
,
selectedSet
:
map
[
string
]
struct
{}{},
want
:
false
,
},
{
name
:
"ID in set is allowed"
,
crsID
:
"abc-123"
,
selectedSet
:
map
[
string
]
struct
{}{
"abc-123"
:
{},
"def-456"
:
{}},
want
:
true
,
},
{
name
:
"ID not in set is blocked"
,
crsID
:
"xyz-789"
,
selectedSet
:
map
[
string
]
struct
{}{
"abc-123"
:
{},
"def-456"
:
{}},
want
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
shouldCreateAccount
(
tt
.
crsID
,
tt
.
selectedSet
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"shouldCreateAccount(%q, %v) = %v, want %v"
,
tt
.
crsID
,
tt
.
selectedSet
,
got
,
tt
.
want
)
}
})
}
}
backend/internal/service/crs_sync_service.go
View file @
d367d1cd
...
...
@@ -49,6 +49,7 @@ type SyncFromCRSInput struct {
Username
string
Password
string
SyncProxies
bool
SelectedAccountIDs
[]
string
// if non-empty, only create new accounts with these CRS IDs
}
type
SyncFromCRSItemResult
struct
{
...
...
@@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct {
Extra
map
[
string
]
any
`json:"extra"`
}
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
// fetchCRSExport validates the connection parameters, authenticates with CRS,
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
func
(
s
*
CRSSyncService
)
fetchCRSExport
(
ctx
context
.
Context
,
baseURL
,
username
,
password
string
)
(
*
crsExportResponse
,
error
)
{
if
s
.
cfg
==
nil
{
return
nil
,
errors
.
New
(
"config is not available"
)
}
base
URL
:=
strings
.
TrimSpace
(
input
.
B
aseURL
)
normalized
URL
:=
strings
.
TrimSpace
(
b
aseURL
)
if
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
normalizeBaseURL
(
base
URL
,
s
.
cfg
.
Security
.
URLAllowlist
.
CRSHosts
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
)
normalized
,
err
:=
normalizeBaseURL
(
normalized
URL
,
s
.
cfg
.
Security
.
URLAllowlist
.
CRSHosts
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
)
if
err
!=
nil
{
return
nil
,
err
}
base
URL
=
normalized
normalized
URL
=
normalized
}
else
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
base
URL
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
normalized
URL
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
base
URL
=
normalized
normalized
URL
=
normalized
}
if
strings
.
TrimSpace
(
input
.
U
sername
)
==
""
||
strings
.
TrimSpace
(
input
.
P
assword
)
==
""
{
if
strings
.
TrimSpace
(
u
sername
)
==
""
||
strings
.
TrimSpace
(
p
assword
)
==
""
{
return
nil
,
errors
.
New
(
"username and password are required"
)
}
...
...
@@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
}
adminToken
,
err
:=
crsLogin
(
ctx
,
client
,
baseURL
,
input
.
U
sername
,
input
.
P
assword
)
adminToken
,
err
:=
crsLogin
(
ctx
,
client
,
normalizedURL
,
u
sername
,
p
assword
)
if
err
!=
nil
{
return
nil
,
err
}
exported
,
err
:=
crsExportAccounts
(
ctx
,
client
,
baseURL
,
adminToken
)
return
crsExportAccounts
(
ctx
,
client
,
normalizedURL
,
adminToken
)
}
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
exported
,
err
:=
s
.
fetchCRSExport
(
ctx
,
input
.
BaseURL
,
input
.
Username
,
input
.
Password
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
selectedSet
:=
buildSelectedSet
(
input
.
SelectedAccountIDs
)
var
proxies
[]
Proxy
if
input
.
SyncProxies
{
proxies
,
_
=
s
.
proxyRepo
.
ListActive
(
ctx
)
...
...
@@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformAnthropic
,
...
...
@@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformAnthropic
,
...
...
@@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformOpenAI
,
...
...
@@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformOpenAI
,
...
...
@@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformGemini
,
...
...
@@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if
existing
==
nil
{
if
!
shouldCreateAccount
(
src
.
ID
,
selectedSet
)
{
item
.
Action
=
"skipped"
item
.
Error
=
"not selected"
result
.
Skipped
++
result
.
Items
=
append
(
result
.
Items
,
item
)
continue
}
account
:=
&
Account
{
Name
:
defaultName
(
src
.
Name
,
src
.
ID
),
Platform
:
PlatformGemini
,
...
...
@@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
return
newCredentials
}
// buildSelectedSet turns the list of selected CRS account IDs into a set so
// membership can be tested with a single map lookup.
//
// A nil slice yields a nil set (field not sent: backward compatible, create
// all). A non-nil slice always yields a non-nil set, so an empty slice yields
// an empty set (user selected none: create none). Duplicate IDs collapse.
func buildSelectedSet(ids []string) map[string]struct{} {
	if ids == nil {
		return nil
	}
	selected := make(map[string]struct{}, len(ids))
	for i := range ids {
		selected[ids[i]] = struct{}{}
	}
	return selected
}
// shouldCreateAccount reports whether a new CRS account with the given ID may
// be created under the user's selection. A nil selectedSet means no selection
// was supplied, so every ID is allowed (backward compatible: create all);
// otherwise the ID is allowed only when it is a key of selectedSet.
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
	if selectedSet != nil {
		_, selected := selectedSet[crsID]
		return selected
	}
	return true
}
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
// Each previewed account appears in exactly one of the two lists.
type PreviewFromCRSResult struct {
	// NewAccounts lists CRS accounts whose CRS ID has no local mapping yet.
	NewAccounts []CRSPreviewAccount `json:"new_accounts"`
	// ExistingAccounts lists CRS accounts whose CRS ID is already mapped locally.
	ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
}
// CRSPreviewAccount represents a single account in the preview result.
type CRSPreviewAccount struct {
	// CRSAccountID is the account's ID on the CRS side.
	CRSAccountID string `json:"crs_account_id"`
	// Kind is the CRS-reported kind of the account, passed through unchanged.
	Kind string `json:"kind"`
	// Name is the display name of the account.
	Name string `json:"name"`
	// Platform is the local platform the account maps to.
	Platform string `json:"platform"`
	// Type is the local account type (for example OAuth or API key).
	Type string `json:"type"`
}
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
// as new or existing by batch-querying local crs_account_id mappings.
func
(
s
*
CRSSyncService
)
PreviewFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
PreviewFromCRSResult
,
error
)
{
exported
,
err
:=
s
.
fetchCRSExport
(
ctx
,
input
.
BaseURL
,
input
.
Username
,
input
.
Password
)
if
err
!=
nil
{
return
nil
,
err
}
// Batch query all existing CRS account IDs
existingCRSIDs
,
err
:=
s
.
accountRepo
.
ListCRSAccountIDs
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to list existing CRS accounts: %w"
,
err
)
}
result
:=
&
PreviewFromCRSResult
{
NewAccounts
:
make
([]
CRSPreviewAccount
,
0
),
ExistingAccounts
:
make
([]
CRSPreviewAccount
,
0
),
}
classify
:=
func
(
crsID
,
kind
,
name
,
platform
,
accountType
string
)
{
preview
:=
CRSPreviewAccount
{
CRSAccountID
:
crsID
,
Kind
:
kind
,
Name
:
defaultName
(
name
,
crsID
),
Platform
:
platform
,
Type
:
accountType
,
}
if
_
,
exists
:=
existingCRSIDs
[
crsID
];
exists
{
result
.
ExistingAccounts
=
append
(
result
.
ExistingAccounts
,
preview
)
}
else
{
result
.
NewAccounts
=
append
(
result
.
NewAccounts
,
preview
)
}
}
for
_
,
src
:=
range
exported
.
Data
.
ClaudeAccounts
{
authType
:=
strings
.
TrimSpace
(
src
.
AuthType
)
if
authType
==
""
{
authType
=
AccountTypeOAuth
}
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformAnthropic
,
authType
)
}
for
_
,
src
:=
range
exported
.
Data
.
ClaudeConsoleAccounts
{
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformAnthropic
,
AccountTypeAPIKey
)
}
for
_
,
src
:=
range
exported
.
Data
.
OpenAIOAuthAccounts
{
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformOpenAI
,
AccountTypeOAuth
)
}
for
_
,
src
:=
range
exported
.
Data
.
OpenAIResponsesAccounts
{
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformOpenAI
,
AccountTypeAPIKey
)
}
for
_
,
src
:=
range
exported
.
Data
.
GeminiOAuthAccounts
{
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformGemini
,
AccountTypeOAuth
)
}
for
_
,
src
:=
range
exported
.
Data
.
GeminiAPIKeyAccounts
{
classify
(
src
.
ID
,
src
.
Kind
,
src
.
Name
,
PlatformGemini
,
AccountTypeAPIKey
)
}
return
result
,
nil
}
backend/internal/service/digest_session_store.go
0 → 100644
View file @
d367d1cd
package
service
import
(
"strconv"
"strings"
"time"
gocache
"github.com/patrickmn/go-cache"
)
// digestSessionTTL is the default TTL of a digest session entry.
const digestSessionTTL = 5 * time.Minute
// sessionEntry is the value stored in the flat cache for one digest chain.
type sessionEntry struct {
	uuid      string // session UUID bound to the chain
	accountID int64  // account ID bound to the chain
}
// DigestSessionStore is an in-memory digest session store (flat cache implementation).
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
	cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func
NewDigestSessionStore
()
*
DigestSessionStore
{
return
&
DigestSessionStore
{
cache
:
gocache
.
New
(
digestSessionTTL
,
time
.
Minute
),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func
(
s
*
DigestSessionStore
)
Save
(
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
,
oldDigestChain
string
)
{
if
digestChain
==
""
{
return
}
ns
:=
buildNS
(
groupID
,
prefixHash
)
s
.
cache
.
Set
(
ns
+
digestChain
,
&
sessionEntry
{
uuid
:
uuid
,
accountID
:
accountID
},
gocache
.
DefaultExpiration
)
if
oldDigestChain
!=
""
&&
oldDigestChain
!=
digestChain
{
s
.
cache
.
Delete
(
ns
+
oldDigestChain
)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func
(
s
*
DigestSessionStore
)
Find
(
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
matchedChain
string
,
found
bool
)
{
if
digestChain
==
""
{
return
""
,
0
,
""
,
false
}
ns
:=
buildNS
(
groupID
,
prefixHash
)
chain
:=
digestChain
for
{
if
val
,
ok
:=
s
.
cache
.
Get
(
ns
+
chain
);
ok
{
if
e
,
ok
:=
val
.
(
*
sessionEntry
);
ok
{
return
e
.
uuid
,
e
.
accountID
,
chain
,
true
}
}
i
:=
strings
.
LastIndex
(
chain
,
"-"
)
if
i
<
0
{
return
""
,
0
,
""
,
false
}
chain
=
chain
[
:
i
]
}
}
// buildNS builds the namespace prefix "{groupID}:{prefixHash}|" that is
// prepended to a digest chain to form a cache key.
func buildNS(groupID int64, prefixHash string) string {
	// 20 bytes covers the longest base-10 int64, plus the two separators.
	buf := make([]byte, 0, 20+len(prefixHash)+2)
	buf = strconv.AppendInt(buf, groupID, 10)
	buf = append(buf, ':')
	buf = append(buf, prefixHash...)
	buf = append(buf, '|')
	return string(buf)
}
backend/internal/service/digest_session_store_test.go
0 → 100644
View file @
d367d1cd
//go:build unit
package
service
import
(
"fmt"
"sync"
"testing"
"time"
gocache
"github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestDigestSessionStore_SaveAndFind
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix"
,
"s:a1-u:b2-m:c3"
,
"uuid-1"
,
100
,
""
)
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"s:a1-u:b2-m:c3"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-1"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
100
),
accountID
)
}
func
TestDigestSessionStore_PrefixMatch
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 保存短链
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-short"
,
10
,
""
)
// 用长链查找,应前缀匹配到短链
uuid
,
accountID
,
matchedChain
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:c-m:d"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-short"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
10
),
accountID
)
assert
.
Equal
(
t
,
"u:a-m:b"
,
matchedChain
)
}
func
TestDigestSessionStore_LongestPrefixMatch
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix"
,
"u:a"
,
"uuid-1"
,
1
,
""
)
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-2"
,
2
,
""
)
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b-u:c"
,
"uuid-3"
,
3
,
""
)
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:c-m:d-u:e"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-3"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
3
),
accountID
)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid
,
accountID
,
_
,
found
=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:x"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-2"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
2
),
accountID
)
}
func
TestDigestSessionStore_SaveDeletesOldChain
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 第一轮:保存 "u:a-m:b"
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b-u:c-m:d"
,
"uuid-1"
,
100
,
"u:a-m:b"
)
// 旧链 "u:a-m:b" 应已被删除
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b"
)
assert
.
False
(
t
,
found
,
"old chain should be deleted"
)
// 新链应能找到
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:c-m:d"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-1"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
100
),
accountID
)
}
func
TestDigestSessionStore_DifferentSessionsNoInterference
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 相同系统提示词,不同用户提示词
store
.
Save
(
1
,
"prefix"
,
"s:sys-u:user1"
,
"uuid-1"
,
100
,
""
)
store
.
Save
(
1
,
"prefix"
,
"s:sys-u:user2"
,
"uuid-2"
,
200
,
""
)
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"s:sys-u:user1-m:reply1"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-1"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
100
),
accountID
)
uuid
,
accountID
,
_
,
found
=
store
.
Find
(
1
,
"prefix"
,
"s:sys-u:user2-m:reply2"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-2"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
200
),
accountID
)
}
func
TestDigestSessionStore_NoMatch
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 完全不同的 chain
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:x-m:y"
)
assert
.
False
(
t
,
found
)
}
func
TestDigestSessionStore_DifferentPrefixHash
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix1"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 不同 prefixHash 应隔离
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix2"
,
"u:a-m:b"
)
assert
.
False
(
t
,
found
)
}
func
TestDigestSessionStore_DifferentGroupID
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 不同 groupID 应隔离
_
,
_
,
_
,
found
:=
store
.
Find
(
2
,
"prefix"
,
"u:a-m:b"
)
assert
.
False
(
t
,
found
)
}
func
TestDigestSessionStore_EmptyDigestChain
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 空链不应保存
store
.
Save
(
1
,
"prefix"
,
""
,
"uuid-1"
,
100
,
""
)
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
""
)
assert
.
False
(
t
,
found
)
}
func
TestDigestSessionStore_TTLExpiration
(
t
*
testing
.
T
)
{
store
:=
&
DigestSessionStore
{
cache
:
gocache
.
New
(
100
*
time
.
Millisecond
,
50
*
time
.
Millisecond
),
}
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 立即应该能找到
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b"
)
require
.
True
(
t
,
found
)
// 等待过期 + 清理周期
time
.
Sleep
(
300
*
time
.
Millisecond
)
// 过期后应找不到
_
,
_
,
_
,
found
=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b"
)
assert
.
False
(
t
,
found
)
}
func
TestDigestSessionStore_ConcurrentSafety
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
var
wg
sync
.
WaitGroup
const
goroutines
=
50
const
operations
=
100
wg
.
Add
(
goroutines
)
for
g
:=
0
;
g
<
goroutines
;
g
++
{
go
func
(
id
int
)
{
defer
wg
.
Done
()
prefix
:=
fmt
.
Sprintf
(
"prefix-%d"
,
id
%
5
)
for
i
:=
0
;
i
<
operations
;
i
++
{
chain
:=
fmt
.
Sprintf
(
"u:%d-m:%d"
,
id
,
i
)
uuid
:=
fmt
.
Sprintf
(
"uuid-%d-%d"
,
id
,
i
)
store
.
Save
(
1
,
prefix
,
chain
,
uuid
,
int64
(
id
),
""
)
store
.
Find
(
1
,
prefix
,
chain
)
}
}(
g
)
}
wg
.
Wait
()
}
func
TestDigestSessionStore_MultipleSessions
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
sessions
:=
[]
struct
{
chain
string
uuid
string
accountID
int64
}{
{
"u:session1"
,
"uuid-1"
,
1
},
{
"u:session2-m:reply2"
,
"uuid-2"
,
2
},
{
"u:session3-m:reply3-u:msg3"
,
"uuid-3"
,
3
},
}
for
_
,
sess
:=
range
sessions
{
store
.
Save
(
1
,
"prefix"
,
sess
.
chain
,
sess
.
uuid
,
sess
.
accountID
,
""
)
}
// 验证每个会话都能正确查找
for
_
,
sess
:=
range
sessions
{
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
sess
.
chain
)
require
.
True
(
t
,
found
,
"should find session: %s"
,
sess
.
chain
)
assert
.
Equal
(
t
,
sess
.
uuid
,
uuid
)
assert
.
Equal
(
t
,
sess
.
accountID
,
accountID
)
}
// 验证继续对话的场景
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:session2-m:reply2-u:newmsg"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-2"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
2
),
accountID
)
}
func
TestDigestSessionStore_Performance1000Sessions
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 插入 1000 个会话
for
i
:=
0
;
i
<
1000
;
i
++
{
chain
:=
fmt
.
Sprintf
(
"s:sys-u:user%d-m:reply%d"
,
i
,
i
)
store
.
Save
(
1
,
"prefix"
,
chain
,
fmt
.
Sprintf
(
"uuid-%d"
,
i
),
int64
(
i
),
""
)
}
// 查找性能测试
start
:=
time
.
Now
()
const
lookups
=
10000
for
i
:=
0
;
i
<
lookups
;
i
++
{
idx
:=
i
%
1000
chain
:=
fmt
.
Sprintf
(
"s:sys-u:user%d-m:reply%d-u:newmsg"
,
idx
,
idx
)
_
,
_
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
chain
)
assert
.
True
(
t
,
found
)
}
elapsed
:=
time
.
Since
(
start
)
t
.
Logf
(
"%d lookups in %v (%.0f ns/op)"
,
lookups
,
elapsed
,
float64
(
elapsed
.
Nanoseconds
())
/
lookups
)
}
func
TestDigestSessionStore_FindReturnsMatchedChain
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b-u:c"
,
"uuid-1"
,
100
,
""
)
// 精确匹配
_
,
_
,
matchedChain
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:c"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"u:a-m:b-u:c"
,
matchedChain
)
// 前缀匹配(截断后命中)
_
,
_
,
matchedChain
,
found
=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b-u:c-m:d-u:e"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"u:a-m:b-u:c"
,
matchedChain
)
}
func
TestDigestSessionStore_CacheItemCountStable
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for
conv
:=
0
;
conv
<
100
;
conv
++
{
var
prevMatchedChain
string
for
round
:=
0
;
round
<
10
;
round
++
{
chain
:=
fmt
.
Sprintf
(
"s:sys-u:user%d"
,
conv
)
for
r
:=
0
;
r
<
round
;
r
++
{
chain
+=
fmt
.
Sprintf
(
"-m:a%d-u:q%d"
,
r
,
r
+
1
)
}
uuid
:=
fmt
.
Sprintf
(
"uuid-conv%d"
,
conv
)
_
,
_
,
matched
,
_
:=
store
.
Find
(
1
,
"prefix"
,
chain
)
store
.
Save
(
1
,
"prefix"
,
chain
,
uuid
,
int64
(
conv
),
matched
)
prevMatchedChain
=
matched
_
=
prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount
:=
store
.
cache
.
ItemCount
()
assert
.
LessOrEqual
(
t
,
itemCount
,
100
,
"cache should have at most 100 items (1 per conversation), got %d"
,
itemCount
)
t
.
Logf
(
"Cache item count after 100 conversations × 10 rounds: %d"
,
itemCount
)
}
func
TestDigestSessionStore_TTLPreventsUnboundedGrowth
(
t
*
testing
.
T
)
{
// 使用极短 TTL 验证大量写入后 cache 能被清理
store
:=
&
DigestSessionStore
{
cache
:
gocache
.
New
(
100
*
time
.
Millisecond
,
50
*
time
.
Millisecond
),
}
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
for
i
:=
0
;
i
<
500
;
i
++
{
chain
:=
fmt
.
Sprintf
(
"u:user%d"
,
i
)
store
.
Save
(
1
,
"prefix"
,
chain
,
fmt
.
Sprintf
(
"uuid-%d"
,
i
),
int64
(
i
),
""
)
}
assert
.
Equal
(
t
,
500
,
store
.
cache
.
ItemCount
())
// 等待 TTL + 清理周期
time
.
Sleep
(
300
*
time
.
Millisecond
)
assert
.
Equal
(
t
,
0
,
store
.
cache
.
ItemCount
(),
"all items should be expired and cleaned up"
)
}
func
TestDigestSessionStore_SaveSameChainNoDelete
(
t
*
testing
.
T
)
{
store
:=
NewDigestSessionStore
()
// 保存 chain
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
""
)
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
store
.
Save
(
1
,
"prefix"
,
"u:a-m:b"
,
"uuid-1"
,
100
,
"u:a-m:b"
)
// 仍然能找到
uuid
,
accountID
,
_
,
found
:=
store
.
Find
(
1
,
"prefix"
,
"u:a-m:b"
)
require
.
True
(
t
,
found
)
assert
.
Equal
(
t
,
"uuid-1"
,
uuid
)
assert
.
Equal
(
t
,
int64
(
100
),
accountID
)
}
backend/internal/service/error_policy_integration_test.go
0 → 100644
View file @
d367d1cd
//go:build unit
package
service
import
(
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream is an upstream stub that answers every request with the same
// status code and body and counts how many times Do ran.
type epFixedUpstream struct {
	statusCode int
	body       string
	calls      int
}

// Do increments the call counter and returns a fresh response carrying the
// configured status code, an empty header set, and the configured body. The
// request and the remaining arguments are ignored; the error is always nil.
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	u.calls++

	resp := new(http.Response)
	resp.StatusCode = u.statusCode
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(u.body))
	return resp, nil
}

// DoWithTLS delegates to Do; enableTLSFingerprint is ignored.
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	resp, err := u.Do(req, proxyURL, accountID, accountConcurrency)
	return resp, err
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type
epAccountRepo
struct
{
mockAccountRepoForGemini
tempCalls
int
setErrCalls
int
}
func
(
r
*
epAccountRepo
)
SetTempUnschedulable
(
_
context
.
Context
,
_
int64
,
_
time
.
Time
,
_
string
)
error
{
r
.
tempCalls
++
return
nil
}
func
(
r
*
epAccountRepo
)
SetError
(
_
context
.
Context
,
_
int64
,
_
string
)
error
{
r
.
setErrCalls
++
return
nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func
saveAndSetBaseURLs
(
t
*
testing
.
T
)
{
t
.
Helper
()
oldBaseURLs
:=
append
([]
string
(
nil
),
antigravity
.
BaseURLs
...
)
oldAvail
:=
antigravity
.
DefaultURLAvailability
antigravity
.
BaseURLs
=
[]
string
{
"https://ep-test.example"
}
antigravity
.
DefaultURLAvailability
=
antigravity
.
NewURLAvailability
(
time
.
Minute
)
t
.
Cleanup
(
func
()
{
antigravity
.
BaseURLs
=
oldBaseURLs
antigravity
.
DefaultURLAvailability
=
oldAvail
})
}
func
newRetryParams
(
account
*
Account
,
upstream
HTTPUpstream
,
handleError
func
(
context
.
Context
,
string
,
*
Account
,
int
,
http
.
Header
,
[]
byte
,
string
,
int64
,
string
,
bool
)
*
handleModelRateLimitResult
)
antigravityRetryLoopParams
{
return
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[ep-test]"
,
account
:
account
,
accessToken
:
"token"
,
action
:
"generateContent"
,
body
:
[]
byte
(
`{"input":"test"}`
),
httpUpstream
:
upstream
,
requestedModel
:
"claude-sonnet-4-5"
,
handleError
:
handleError
,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func
TestRetryLoop_ErrorPolicy_CustomErrorCodes
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
upstreamStatus
int
upstreamBody
string
customCodes
[]
any
expectHandleError
int
expectUpstream
int
expectStatusCode
int
}{
{
name
:
"429_in_custom_codes_matched"
,
upstreamStatus
:
429
,
upstreamBody
:
`{"error":"rate limited"}`
,
customCodes
:
[]
any
{
float64
(
429
)},
expectHandleError
:
1
,
expectUpstream
:
1
,
expectStatusCode
:
429
,
},
{
name
:
"429_not_in_custom_codes_skipped"
,
upstreamStatus
:
429
,
upstreamBody
:
`{"error":"rate limited"}`
,
customCodes
:
[]
any
{
float64
(
500
)},
expectHandleError
:
0
,
expectUpstream
:
1
,
expectStatusCode
:
429
,
},
{
name
:
"500_in_custom_codes_matched"
,
upstreamStatus
:
500
,
upstreamBody
:
`{"error":"internal"}`
,
customCodes
:
[]
any
{
float64
(
500
)},
expectHandleError
:
1
,
expectUpstream
:
1
,
expectStatusCode
:
500
,
},
{
name
:
"500_not_in_custom_codes_skipped"
,
upstreamStatus
:
500
,
upstreamBody
:
`{"error":"internal"}`
,
customCodes
:
[]
any
{
float64
(
429
)},
expectHandleError
:
0
,
expectUpstream
:
1
,
expectStatusCode
:
500
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
tt
.
upstreamStatus
,
body
:
tt
.
upstreamBody
}
repo
:=
&
epAccountRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
account
:=
&
Account
{
ID
:
100
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Schedulable
:
true
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
tt
.
customCodes
,
},
}
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
}
var
handleErrorCount
int
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
handleErrorCount
++
return
nil
})
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
resp
)
defer
func
()
{
_
=
result
.
resp
.
Body
.
Close
()
}()
require
.
Equal
(
t
,
tt
.
expectStatusCode
,
result
.
resp
.
StatusCode
)
require
.
Equal
(
t
,
tt
.
expectHandleError
,
handleErrorCount
,
"handleError call count"
)
require
.
Equal
(
t
,
tt
.
expectUpstream
,
upstream
.
calls
,
"upstream call count"
)
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func
TestRetryLoop_ErrorPolicy_TempUnschedulable
(
t
*
testing
.
T
)
{
tempRulesAccount
:=
func
(
rules
[]
any
)
*
Account
{
return
&
Account
{
ID
:
200
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Schedulable
:
true
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
rules
,
},
}
}
overloadedRule
:=
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
}
rateLimitRule
:=
map
[
string
]
any
{
"error_code"
:
float64
(
429
),
"keywords"
:
[]
any
{
"rate limited keyword"
},
"duration_minutes"
:
float64
(
5
),
}
t
.
Run
(
"503_overloaded_matches_rule"
,
func
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
503
,
body
:
`overloaded`
}
repo
:=
&
epAccountRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
}
account
:=
tempRulesAccount
([]
any
{
overloadedRule
})
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
t
.
Error
(
"handleError should not be called for temp unschedulable"
)
return
nil
})
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
require
.
Nil
(
t
,
result
)
var
switchErr
*
AntigravityAccountSwitchError
require
.
ErrorAs
(
t
,
err
,
&
switchErr
)
require
.
Equal
(
t
,
account
.
ID
,
switchErr
.
OriginalAccountID
)
require
.
Equal
(
t
,
1
,
upstream
.
calls
,
"should not retry"
)
})
t
.
Run
(
"429_rate_limited_keyword_matches_rule"
,
func
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
429
,
body
:
`rate limited keyword`
}
repo
:=
&
epAccountRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
}
account
:=
tempRulesAccount
([]
any
{
rateLimitRule
})
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
t
.
Error
(
"handleError should not be called for temp unschedulable"
)
return
nil
})
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
require
.
Nil
(
t
,
result
)
var
switchErr
*
AntigravityAccountSwitchError
require
.
ErrorAs
(
t
,
err
,
&
switchErr
)
require
.
Equal
(
t
,
account
.
ID
,
switchErr
.
OriginalAccountID
)
require
.
Equal
(
t
,
1
,
upstream
.
calls
,
"should not retry"
)
})
t
.
Run
(
"503_body_no_match_continues_default_retry"
,
func
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
503
,
body
:
`random`
}
repo
:=
&
epAccountRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
}
account
:=
tempRulesAccount
([]
any
{
overloadedRule
})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
100
*
time
.
Millisecond
)
defer
cancel
()
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
return
nil
})
p
.
ctx
=
ctx
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
// Context cancellation during backoff proves default retry was entered
require
.
Nil
(
t
,
result
)
require
.
ErrorIs
(
t
,
err
,
context
.
DeadlineExceeded
)
require
.
GreaterOrEqual
(
t
,
upstream
.
calls
,
1
,
"should have called upstream at least once"
)
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func
TestRetryLoop_ErrorPolicy_NilRateLimitService
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
429
,
body
:
`{"error":"rate limited"}`
}
// rateLimitService is nil — must not panic
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
nil
}
account
:=
&
Account
{
ID
:
300
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Schedulable
:
true
,
Status
:
StatusActive
,
Concurrency
:
1
,
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
100
*
time
.
Millisecond
)
defer
cancel
()
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
return
nil
})
p
.
ctx
=
ctx
// Should not panic; enters the default retry path (eventually times out)
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
require
.
Nil
(
t
,
result
)
require
.
ErrorIs
(
t
,
err
,
context
.
DeadlineExceeded
)
require
.
GreaterOrEqual
(
t
,
upstream
.
calls
,
1
)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func
TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
(
t
*
testing
.
T
)
{
saveAndSetBaseURLs
(
t
)
upstream
:=
&
epFixedUpstream
{
statusCode
:
429
,
body
:
`{"error":"rate limited"}`
}
repo
:=
&
epAccountRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
}
// Plain OAuth account with no error policy configured
account
:=
&
Account
{
ID
:
400
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Schedulable
:
true
,
Status
:
StatusActive
,
Concurrency
:
1
,
}
var
handleErrorCount
int
p
:=
newRetryParams
(
account
,
upstream
,
func
(
_
context
.
Context
,
_
string
,
_
*
Account
,
_
int
,
_
http
.
Header
,
_
[]
byte
,
_
string
,
_
int64
,
_
string
,
_
bool
)
*
handleModelRateLimitResult
{
handleErrorCount
++
return
nil
})
result
,
err
:=
svc
.
antigravityRetryLoop
(
p
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
resp
)
defer
func
()
{
_
=
result
.
resp
.
Body
.
Close
()
}()
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
result
.
resp
.
StatusCode
)
require
.
Equal
(
t
,
antigravityMaxRetries
,
upstream
.
calls
,
"should exhaust all retries"
)
require
.
Equal
(
t
,
1
,
handleErrorCount
,
"handleError should be called once after retries exhausted"
)
}
backend/internal/service/error_policy_test.go
0 → 100644
View file @
d367d1cd
//go:build unit
package
service
import
(
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func
TestCheckErrorPolicy
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
body
[]
byte
expected
ErrorPolicyResult
}{
{
name
:
"no_policy_oauth_returns_none"
,
account
:
&
Account
{
ID
:
1
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
// no custom error codes, no temp rules
},
statusCode
:
500
,
body
:
[]
byte
(
`"error"`
),
expected
:
ErrorPolicyNone
,
},
{
name
:
"custom_error_codes_hit_returns_matched"
,
account
:
&
Account
{
ID
:
2
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
),
float64
(
500
)},
},
},
statusCode
:
500
,
body
:
[]
byte
(
`"error"`
),
expected
:
ErrorPolicyMatched
,
},
{
name
:
"custom_error_codes_miss_returns_skipped"
,
account
:
&
Account
{
ID
:
3
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
),
float64
(
500
)},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`"error"`
),
expected
:
ErrorPolicySkipped
,
},
{
name
:
"temp_unschedulable_hit_returns_temp_unscheduled"
,
account
:
&
Account
{
ID
:
4
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
"description"
:
"overloaded rule"
,
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded service`
),
expected
:
ErrorPolicyTempUnscheduled
,
},
{
name
:
"temp_unschedulable_body_miss_returns_none"
,
account
:
&
Account
{
ID
:
5
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
"description"
:
"overloaded rule"
,
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`random msg`
),
expected
:
ErrorPolicyNone
,
},
{
name
:
"custom_error_codes_override_temp_unschedulable"
,
account
:
&
Account
{
ID
:
6
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
503
)},
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
"description"
:
"overloaded rule"
,
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded`
),
expected
:
ErrorPolicyMatched
,
// custom codes take precedence
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
errorPolicyRepoStub
{}
svc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
result
:=
svc
.
CheckErrorPolicy
(
context
.
Background
(),
tt
.
account
,
tt
.
statusCode
,
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expected
,
result
,
"unexpected ErrorPolicyResult"
)
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func
TestApplyErrorPolicy
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
body
[]
byte
expectedHandled
bool
expectedSwitchErr
bool
// expect *AntigravityAccountSwitchError
handleErrorCalls
int
}{
{
name
:
"none_not_handled"
,
account
:
&
Account
{
ID
:
10
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
},
statusCode
:
500
,
body
:
[]
byte
(
`"error"`
),
expectedHandled
:
false
,
handleErrorCalls
:
0
,
},
{
name
:
"skipped_handled_no_handleError"
,
account
:
&
Account
{
ID
:
11
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
500
,
// not in custom codes
body
:
[]
byte
(
`"error"`
),
expectedHandled
:
true
,
handleErrorCalls
:
0
,
},
{
name
:
"matched_handled_calls_handleError"
,
account
:
&
Account
{
ID
:
12
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
500
)},
},
},
statusCode
:
500
,
body
:
[]
byte
(
`"error"`
),
expectedHandled
:
true
,
handleErrorCalls
:
1
,
},
{
name
:
"temp_unscheduled_returns_switch_error"
,
account
:
&
Account
{
ID
:
13
,
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded`
),
expectedHandled
:
true
,
expectedSwitchErr
:
true
,
handleErrorCalls
:
0
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
errorPolicyRepoStub
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
AntigravityGatewayService
{
rateLimitService
:
rlSvc
,
}
var
handleErrorCount
int
p
:=
antigravityRetryLoopParams
{
ctx
:
context
.
Background
(),
prefix
:
"[test]"
,
account
:
tt
.
account
,
handleError
:
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
requestedModel
string
,
groupID
int64
,
sessionHash
string
,
isStickySession
bool
)
*
handleModelRateLimitResult
{
handleErrorCount
++
return
nil
},
isStickySession
:
true
,
}
handled
,
retErr
:=
svc
.
applyErrorPolicy
(
p
,
tt
.
statusCode
,
http
.
Header
{},
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expectedHandled
,
handled
,
"handled mismatch"
)
require
.
Equal
(
t
,
tt
.
handleErrorCalls
,
handleErrorCount
,
"handleError call count mismatch"
)
if
tt
.
expectedSwitchErr
{
var
switchErr
*
AntigravityAccountSwitchError
require
.
ErrorAs
(
t
,
retErr
,
&
switchErr
)
require
.
Equal
(
t
,
tt
.
account
.
ID
,
switchErr
.
OriginalAccountID
)
}
else
{
require
.
NoError
(
t
,
retErr
)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type
errorPolicyRepoStub
struct
{
mockAccountRepoForGemini
tempCalls
int
setErrCalls
int
lastErrorMsg
string
}
func
(
r
*
errorPolicyRepoStub
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
r
.
tempCalls
++
return
nil
}
func
(
r
*
errorPolicyRepoStub
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrCalls
++
r
.
lastErrorMsg
=
errorMsg
return
nil
}
backend/internal/service/gateway_multiplatform_test.go
View file @
d367d1cd
...
...
@@ -77,7 +77,12 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
func
(
m
*
mockAccountRepoForPlatform
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
interface
{})
([]
Account
,
error
)
{
func
(
m
*
mockAccountRepoForPlatform
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
...
...
@@ -145,9 +150,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func
(
m
*
mockAccountRepoForPlatform
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
...
...
@@ -219,22 +221,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return
nil
}
func
(
m
*
mockGatewayCacheForPlatform
)
IncrModelCallCount
(
ctx
context
.
Context
,
accountID
int64
,
model
string
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockGatewayCacheForPlatform
)
GetModelLoadBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
model
string
)
(
map
[
int64
]
*
ModelLoadInfo
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGatewayCacheForPlatform
)
FindGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
found
bool
)
{
return
""
,
0
,
false
}
func
(
m
*
mockGatewayCacheForPlatform
)
SaveGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
)
error
{
return
nil
}
type
mockGroupRepoForGateway
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
...
...
@@ -293,6 +279,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
...
...
backend/internal/service/gateway_request.go
View file @
d367d1cd
...
...
@@ -6,9 +6,19 @@ import (
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext is the sticky-session context used to distinguish requests
// coming from different sources.
// It is mixed in only at the third-level fallback of GenerateSessionHash (the
// message-content hash), so that different users sending identical messages do
// not produce the same hash and end up concentrated on one account.
type SessionContext struct {
	ClientIP, UserAgent string
	APIKeyID            int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
...
...
@@ -31,11 +41,13 @@ type ParsedRequest struct {
HasSystem
bool
// 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled
bool
// 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens
int
// max_tokens 值(用于探测请求拦截)
SessionContext
*
SessionContext
// 可选:请求上下文区分因子(nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func
ParseGatewayRequest
(
body
[]
byte
)
(
*
ParsedRequest
,
error
)
{
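// ParseGatewayRequest parses the gateway request body and returns a structured result.
// protocol selects the request protocol format (domain.PlatformAnthropic / domain.PlatformGemini);
// different protocols use different field names for system/messages.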
func
ParseGatewayRequest
(
body
[]
byte
,
protocol
string
)
(
*
ParsedRequest
,
error
)
{
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -64,6 +76,20 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed
.
MetadataUserID
=
userID
}
}
switch
protocol
{
case
domain
.
PlatformGemini
:
// Gemini 原生格式: systemInstruction.parts / contents
if
sysInst
,
ok
:=
req
[
"systemInstruction"
]
.
(
map
[
string
]
any
);
ok
{
if
parts
,
ok
:=
sysInst
[
"parts"
]
.
([]
any
);
ok
{
parsed
.
System
=
parts
}
}
if
contents
,
ok
:=
req
[
"contents"
]
.
([]
any
);
ok
{
parsed
.
Messages
=
contents
}
default
:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if
system
,
ok
:=
req
[
"system"
];
ok
{
...
...
@@ -73,6 +99,7 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
parsed
.
Messages
=
messages
}
}
// thinking: {type: "enabled"}
if
rawThinking
,
ok
:=
req
[
"thinking"
]
.
(
map
[
string
]
any
);
ok
{
...
...
backend/internal/service/gateway_request_test.go
View file @
d367d1cd
...
...
@@ -4,12 +4,13 @@ import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require"
)
func
TestParseGatewayRequest
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"claude-3-7-sonnet"
,
parsed
.
Model
)
require
.
True
(
t
,
parsed
.
Stream
)
...
...
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func
TestParseGatewayRequest_ThinkingEnabled
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
parsed
.
Model
)
require
.
True
(
t
,
parsed
.
ThinkingEnabled
)
...
...
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func
TestParseGatewayRequest_MaxTokens
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":"claude-haiku-4-5","max_tokens":1}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
parsed
.
MaxTokens
)
}
func
TestParseGatewayRequest_MaxTokensNonIntegralIgnored
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":"claude-haiku-4-5","max_tokens":1.5}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
0
,
parsed
.
MaxTokens
)
}
func
TestParseGatewayRequest_SystemNull
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":"claude-3","system":null}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
NoError
(
t
,
err
)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require
.
True
(
t
,
parsed
.
HasSystem
)
...
...
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func
TestParseGatewayRequest_InvalidModelType
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model":123}`
)
_
,
err
:=
ParseGatewayRequest
(
body
)
_
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
Error
(
t
,
err
)
}
func
TestParseGatewayRequest_InvalidStreamType
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"stream":"true"}`
)
_
,
err
:=
ParseGatewayRequest
(
body
)
_
,
err
:=
ParseGatewayRequest
(
body
,
""
)
require
.
Error
(
t
,
err
)
}
// ============ Gemini native-format parsing tests ============
func
TestParseGatewayRequest_GeminiContents
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
parsed
.
Messages
,
3
,
"should parse contents as Messages"
)
require
.
False
(
t
,
parsed
.
HasSystem
,
"Gemini format should not set HasSystem"
)
require
.
Nil
(
t
,
parsed
.
System
,
"no systemInstruction means nil System"
)
}
func
TestParseGatewayRequest_GeminiSystemInstruction
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
parsed
.
System
,
"should parse systemInstruction.parts as System"
)
parts
,
ok
:=
parsed
.
System
.
([]
any
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
parts
,
1
)
partMap
,
ok
:=
parts
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"You are a helpful assistant."
,
partMap
[
"text"
])
require
.
Len
(
t
,
parsed
.
Messages
,
1
)
}
func
TestParseGatewayRequest_GeminiWithModel
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"gemini-2.5-pro"
,
parsed
.
Model
)
require
.
Len
(
t
,
parsed
.
Messages
,
1
)
}
func
TestParseGatewayRequest_GeminiIgnoresAnthropicFields
(
t
*
testing
.
T
)
{
// Gemini 格式下 system/messages 字段应被忽略
body
:=
[]
byte
(
`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
parsed
.
HasSystem
,
"Gemini protocol should not parse Anthropic system field"
)
require
.
Nil
(
t
,
parsed
.
System
,
"no systemInstruction = nil System"
)
require
.
Len
(
t
,
parsed
.
Messages
,
1
,
"should use contents, not messages"
)
}
func
TestParseGatewayRequest_GeminiEmptyContents
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"contents": []}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
parsed
.
Messages
)
}
func
TestParseGatewayRequest_GeminiNoContents
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"model": "gemini-2.5-flash"}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
parsed
.
Messages
)
require
.
Equal
(
t
,
"gemini-2.5-flash"
,
parsed
.
Model
)
}
func
TestParseGatewayRequest_AnthropicIgnoresGeminiFields
(
t
*
testing
.
T
)
{
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body
:=
[]
byte
(
`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
domain
.
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
parsed
.
HasSystem
)
require
.
Equal
(
t
,
"real system"
,
parsed
.
System
)
require
.
Len
(
t
,
parsed
.
Messages
,
1
)
msg
,
ok
:=
parsed
.
Messages
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"real content"
,
msg
[
"content"
])
}
func
TestFilterThinkingBlocks
(
t
*
testing
.
T
)
{
containsThinkingBlock
:=
func
(
body
[]
byte
)
bool
{
var
req
map
[
string
]
any
...
...
backend/internal/service/gateway_service.go
View file @
d367d1cd
...
...
@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
...
...
@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
...
...
@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -245,9 +246,6 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var
ErrClaudeCodeOnly
=
errors
.
New
(
"this group only allows Claude Code clients"
)
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var
ErrModelScopeNotSupported
=
errors
.
New
(
"model scope not supported by this group"
)
// allowedHeaders 白名单headers(参考CRS项目)
var
allowedHeaders
=
map
[
string
]
bool
{
"accept"
:
true
,
...
...
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type
ModelLoadInfo
struct
{
CallCount
int64
// 当前分钟调用次数 / Call count in current minute
LastUsedAt
time
.
Time
// 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type
GatewayCache
interface
{
...
...
@@ -295,24 +286,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount
(
ctx
context
.
Context
,
accountID
int64
,
model
string
)
(
int64
,
error
)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
model
string
)
(
map
[
int64
]
*
ModelLoadInfo
,
error
)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
found
bool
)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
)
error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...
...
@@ -323,21 +296,15 @@ func derefGroupID(groupID *int64) int64 {
return
*
groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const
stickySessionRateLimitThreshold
=
10
*
time
.
Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或
模型限流剩余时间超过 stickySessionRateLimitThreshold
时,返回 true。
// 或
请求的模型处于限流状态
时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// within temporary unschedulable period, or the requested model is rate-limited.
// This ensures subsequent requests won't continue using unavailable accounts.
func
shouldClearStickySession
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
==
nil
{
...
...
@@ -349,8 +316,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
if
account
.
TempUnschedulableUntil
!=
nil
&&
time
.
Now
()
.
Before
(
*
account
.
TempUnschedulableUntil
)
{
return
true
}
// 检查模型限流和 scope 限流,
只在超过阈值时
清除粘性会话
if
remaining
:=
account
.
GetRateLimitRemainingTimeWithContext
(
context
.
Background
(),
requestedModel
);
remaining
>
stickySessionRateLimitThreshold
{
// 检查模型限流和 scope 限流,
有限流即
清除粘性会话
if
remaining
:=
account
.
GetRateLimitRemainingTimeWithContext
(
context
.
Background
(),
requestedModel
);
remaining
>
0
{
return
true
}
return
false
...
...
@@ -417,6 +384,7 @@ type GatewayService struct {
userSubRepo
UserSubscriptionRepository
userGroupRateRepo
UserGroupRateRepository
cache
GatewayCache
digestStore
*
DigestSessionStore
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
...
...
@@ -450,6 +418,7 @@ func NewGatewayService(
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
sessionLimitCache
SessionLimitCache
,
digestStore
*
DigestSessionStore
,
)
*
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
...
...
@@ -459,6 +428,7 @@ func NewGatewayService(
userSubRepo
:
userSubRepo
,
userGroupRateRepo
:
userGroupRateRepo
,
cache
:
cache
,
digestStore
:
digestStore
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
...
...
@@ -492,23 +462,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return
s
.
hashContent
(
cacheableContent
)
}
// 3. Fallback: 使用 system 内容
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var
combined
strings
.
Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if
parsed
.
SessionContext
!=
nil
{
_
,
_
=
combined
.
WriteString
(
parsed
.
SessionContext
.
ClientIP
)
_
,
_
=
combined
.
WriteString
(
":"
)
_
,
_
=
combined
.
WriteString
(
parsed
.
SessionContext
.
UserAgent
)
_
,
_
=
combined
.
WriteString
(
":"
)
_
,
_
=
combined
.
WriteString
(
strconv
.
FormatInt
(
parsed
.
SessionContext
.
APIKeyID
,
10
))
_
,
_
=
combined
.
WriteString
(
"|"
)
}
if
parsed
.
System
!=
nil
{
systemText
:=
s
.
extractTextFromSystem
(
parsed
.
System
)
if
systemText
!=
""
{
return
s
.
hashContent
(
systemText
)
_
,
_
=
combined
.
WriteString
(
systemText
)
}
}
for
_
,
msg
:=
range
parsed
.
Messages
{
if
m
,
ok
:=
msg
.
(
map
[
string
]
any
);
ok
{
if
content
,
exists
:=
m
[
"content"
];
exists
{
// Anthropic: messages[].content
if
msgText
:=
s
.
extractTextFromContent
(
content
);
msgText
!=
""
{
_
,
_
=
combined
.
WriteString
(
msgText
)
}
}
else
if
parts
,
ok
:=
m
[
"parts"
]
.
([]
any
);
ok
{
// Gemini: contents[].parts[].text
for
_
,
part
:=
range
parts
{
if
partMap
,
ok
:=
part
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
partMap
[
"text"
]
.
(
string
);
ok
{
_
,
_
=
combined
.
WriteString
(
text
)
}
}
// 4. 最后 fallback: 使用第一条消息
if
len
(
parsed
.
Messages
)
>
0
{
if
firstMsg
,
ok
:=
parsed
.
Messages
[
0
]
.
(
map
[
string
]
any
);
ok
{
msgText
:=
s
.
extractTextFromContent
(
firstMsg
[
"content"
])
if
msgText
!=
""
{
return
s
.
hashContent
(
msgText
)
}
}
}
}
if
combined
.
Len
()
>
0
{
return
s
.
hashContent
(
combined
.
String
())
}
return
""
}
...
...
@@ -536,19 +528,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息(uuid, accountID)
func
(
s
*
GatewayService
)
FindGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
found
bool
)
{
if
digestChain
==
""
||
s
.
cach
e
==
nil
{
return
""
,
0
,
false
func
(
s
*
GatewayService
)
FindGeminiSession
(
_
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
matchedChain
string
,
found
bool
)
{
if
digestChain
==
""
||
s
.
digestStor
e
==
nil
{
return
""
,
0
,
""
,
false
}
return
s
.
cache
.
FindGeminiSession
(
ctx
,
groupID
,
prefixHash
,
digestChain
)
return
s
.
digestStore
.
Find
(
groupID
,
prefixHash
,
digestChain
)
}
// SaveGeminiSession 保存 Gemini 会话
func
(
s
*
GatewayService
)
SaveGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
)
error
{
if
digestChain
==
""
||
s
.
cach
e
==
nil
{
// SaveGeminiSession 保存 Gemini 会话
。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func
(
s
*
GatewayService
)
SaveGeminiSession
(
_
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
,
oldDigestChain
string
)
error
{
if
digestChain
==
""
||
s
.
digestStor
e
==
nil
{
return
nil
}
return
s
.
cache
.
SaveGeminiSession
(
ctx
,
groupID
,
prefixHash
,
digestChain
,
uuid
,
accountID
)
s
.
digestStore
.
Save
(
groupID
,
prefixHash
,
digestChain
,
uuid
,
accountID
,
oldDigestChain
)
return
nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func
(
s
*
GatewayService
)
FindAnthropicSession
(
_
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
matchedChain
string
,
found
bool
)
{
if
digestChain
==
""
||
s
.
digestStore
==
nil
{
return
""
,
0
,
""
,
false
}
return
s
.
digestStore
.
Find
(
groupID
,
prefixHash
,
digestChain
)
}
// SaveAnthropicSession 保存 Anthropic 会话
func
(
s
*
GatewayService
)
SaveAnthropicSession
(
_
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
,
oldDigestChain
string
)
error
{
if
digestChain
==
""
||
s
.
digestStore
==
nil
{
return
nil
}
s
.
digestStore
.
Save
(
groupID
,
prefixHash
,
digestChain
,
uuid
,
accountID
,
oldDigestChain
)
return
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
...
...
@@ -633,8 +643,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func
(
s
*
GatewayService
)
hashContent
(
content
string
)
string
{
h
ash
:=
sha256
.
Sum256
([]
byte
(
content
)
)
return
hex
.
EncodeToStr
in
g
(
h
ash
[
:
16
])
// 32字符
h
:=
xxhash
.
Sum64String
(
content
)
return
strconv
.
FormatU
in
t
(
h
,
36
)
}
// replaceModelInBody 替换请求体中的model字段
...
...
@@ -993,13 +1003,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -1114,7 +1117,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
}
...
...
@@ -1194,6 +1196,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
a
.
account
.
LastUsedAt
.
Before
(
*
b
.
account
.
LastUsedAt
)
}
})
shuffleWithinSortGroups
(
routingAvailable
)
// 4. 尝试获取槽位
for
_
,
item
:=
range
routingAvailable
{
...
...
@@ -1268,7 +1271,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
...
...
@@ -1348,10 +1350,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
result
,
nil
}
}
else
{
// Antigravity 平台:获取模型负载信息
var
modelLoadMap
map
[
int64
]
*
ModelLoadInfo
isAntigravity
:=
platform
==
PlatformAntigravity
var
available
[]
accountWithLoad
for
_
,
acc
:=
range
candidates
{
loadInfo
:=
loadMap
[
acc
.
ID
]
...
...
@@ -1366,71 +1364,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if
isAntigravity
&&
requestedModel
!=
""
&&
s
.
cache
!=
nil
&&
len
(
available
)
>
0
{
modelLoadMap
=
make
(
map
[
int64
]
*
ModelLoadInfo
,
len
(
available
))
modelToAccountIDs
:=
make
(
map
[
string
][]
int64
)
for
_
,
item
:=
range
available
{
mappedModel
:=
mapAntigravityModel
(
item
.
account
,
requestedModel
)
if
mappedModel
==
""
{
continue
}
modelToAccountIDs
[
mappedModel
]
=
append
(
modelToAccountIDs
[
mappedModel
],
item
.
account
.
ID
)
}
for
model
,
ids
:=
range
modelToAccountIDs
{
batch
,
err
:=
s
.
cache
.
GetModelLoadBatch
(
ctx
,
ids
,
model
)
if
err
!=
nil
{
continue
}
for
id
,
info
:=
range
batch
{
modelLoadMap
[
id
]
=
info
}
}
if
len
(
modelLoadMap
)
==
0
{
modelLoadMap
=
nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if
isAntigravity
{
for
len
(
available
)
>
0
{
// 1. 取优先级最小的集合(硬过滤)
candidates
:=
filterByMinPriority
(
available
)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected
:=
selectByCallCount
(
candidates
,
modelLoadMap
,
preferOAuth
)
if
selected
==
nil
{
break
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
selected
.
account
.
ID
,
selected
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
selected
.
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
}
else
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
account
.
ID
,
stickySessionTTL
)
}
return
&
AccountSelectionResult
{
Account
:
selected
.
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
// 移除已尝试的账号,重新选择
selectedID
:=
selected
.
account
.
ID
newAvailable
:=
make
([]
accountWithLoad
,
0
,
len
(
available
)
-
1
)
for
_
,
acc
:=
range
available
{
if
acc
.
account
.
ID
!=
selectedID
{
newAvailable
=
append
(
newAvailable
,
acc
)
}
}
available
=
newAvailable
}
}
else
{
// 分层过滤选择:优先级 → 负载率 → LRU
for
len
(
available
)
>
0
{
// 1. 取优先级最小的集合
candidates
:=
filterByMinPriority
(
available
)
...
...
@@ -1470,7 +1404,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
available
=
newAvailable
}
}
}
// ============ Layer 3: 兜底排队 ============
s
.
sortCandidatesForFallback
(
candidates
,
preferOAuth
,
cfg
.
FallbackSelectionMode
)
...
...
@@ -2004,87 +1937,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return
a
.
LastUsedAt
.
Before
(
*
b
.
LastUsedAt
)
}
})
shuffleWithinPriorityAndLastUsed
(
accounts
)
}
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func
selectByCallCount
(
accounts
[]
accountWithLoad
,
modelLoadMap
map
[
int64
]
*
ModelLoadInfo
,
preferOAuth
bool
)
*
accountWithLoad
{
if
len
(
accounts
)
==
0
{
return
nil
}
if
len
(
accounts
)
==
1
{
return
&
accounts
[
0
]
}
// 如果没有负载信息,回退到 LRU
if
modelLoadMap
==
nil
{
return
selectByLRU
(
accounts
,
preferOAuth
)
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
func
shuffleWithinSortGroups
(
accounts
[]
accountWithLoad
)
{
if
len
(
accounts
)
<=
1
{
return
}
// 1. 计算平均调用次数(用于新账号冷启动)
var
totalCallCount
int64
var
countWithCalls
int
for
_
,
acc
:=
range
accounts
{
if
info
:=
modelLoadMap
[
acc
.
account
.
ID
];
info
!=
nil
&&
info
.
CallCount
>
0
{
totalCallCount
+=
info
.
CallCount
countWithCalls
++
i
:=
0
for
i
<
len
(
accounts
)
{
j
:=
i
+
1
for
j
<
len
(
accounts
)
&&
sameAccountWithLoadGroup
(
accounts
[
i
],
accounts
[
j
])
{
j
++
}
if
j
-
i
>
1
{
mathrand
.
Shuffle
(
j
-
i
,
func
(
a
,
b
int
)
{
accounts
[
i
+
a
],
accounts
[
i
+
b
]
=
accounts
[
i
+
b
],
accounts
[
i
+
a
]
})
}
var
avgCallCount
int64
if
countWithCalls
>
0
{
avgCallCount
=
totalCallCount
/
int64
(
countWithCalls
)
i
=
j
}
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount
:=
func
(
acc
accountWithLoad
)
int64
{
if
acc
.
account
==
nil
{
return
0
}
info
:=
modelLoadMap
[
acc
.
account
.
ID
]
if
info
==
nil
||
info
.
CallCount
==
0
{
return
avgCallCount
// 新账号使用平均值
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
func
sameAccountWithLoadGroup
(
a
,
b
accountWithLoad
)
bool
{
if
a
.
account
.
Priority
!=
b
.
account
.
Priority
{
return
false
}
return
info
.
CallCount
if
a
.
loadInfo
.
LoadRate
!=
b
.
loadInfo
.
LoadRate
{
return
false
}
return
sameLastUsedAt
(
a
.
account
.
LastUsedAt
,
b
.
account
.
LastUsedAt
)
}
// 3. 找到最小调用次数
minCount
:=
getEffectiveCallCount
(
accounts
[
0
])
for
_
,
acc
:=
range
accounts
[
1
:
]
{
if
c
:=
getEffectiveCallCount
(
acc
);
c
<
minCount
{
minCount
=
c
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
func
shuffleWithinPriorityAndLastUsed
(
accounts
[]
*
Account
)
{
if
len
(
accounts
)
<=
1
{
return
}
i
:=
0
for
i
<
len
(
accounts
)
{
j
:=
i
+
1
for
j
<
len
(
accounts
)
&&
sameAccountGroup
(
accounts
[
i
],
accounts
[
j
])
{
j
++
}
// 4. 收集所有具有最小调用次数的账号
var
candidateIdxs
[]
int
for
i
,
acc
:=
range
accounts
{
if
getEffectiveCallCount
(
acc
)
==
minCount
{
candidateIdxs
=
append
(
candidateIdxs
,
i
)
if
j
-
i
>
1
{
mathrand
.
Shuffle
(
j
-
i
,
func
(
a
,
b
int
)
{
accounts
[
i
+
a
],
accounts
[
i
+
b
]
=
accounts
[
i
+
b
],
accounts
[
i
+
a
]
})
}
i
=
j
}
}
// 5. 如果只有一个候选,直接返回
if
len
(
candidateIdxs
)
==
1
{
return
&
accounts
[
candidateIdxs
[
0
]]
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
func
sameAccountGroup
(
a
,
b
*
Account
)
bool
{
if
a
.
Priority
!=
b
.
Priority
{
return
false
}
return
sameLastUsedAt
(
a
.
LastUsedAt
,
b
.
LastUsedAt
)
}
// 6. preferOAuth 处理
if
preferOAuth
{
var
oauthIdxs
[]
int
for
_
,
idx
:=
range
candidateIdxs
{
if
accounts
[
idx
]
.
account
.
Type
==
AccountTypeOAuth
{
oauthIdxs
=
append
(
oauthIdxs
,
idx
)
}
}
if
len
(
oauthIdxs
)
>
0
{
candidateIdxs
=
oauthIdxs
}
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
func
sameLastUsedAt
(
a
,
b
*
time
.
Time
)
bool
{
switch
{
case
a
==
nil
&&
b
==
nil
:
return
true
case
a
==
nil
||
b
==
nil
:
return
false
default
:
return
a
.
Unix
()
==
b
.
Unix
()
}
// 7. 随机选择
return
&
accounts
[
candidateIdxs
[
mathrand
.
Intn
(
len
(
candidateIdxs
))]]
}
// sortCandidatesForFallback 根据配置选择排序策略
...
...
@@ -2139,13 +2064,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
...
...
@@ -2173,9 +2091,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
account
.
IsSchedulableForModelWithContext
(
ctx
,
requestedModel
)
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
...
...
@@ -2276,9 +2191,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
account
.
IsSchedulableForModelWithContext
(
ctx
,
requestedModel
)
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
return
account
,
nil
}
}
...
...
@@ -2387,9 +2299,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
account
.
IsSchedulableForModelWithContext
(
ctx
,
requestedModel
)
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
...
...
@@ -2492,9 +2401,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccountWithContext
(
ctx
,
account
,
requestedModel
))
&&
account
.
IsSchedulableForModelWithContext
(
ctx
,
requestedModel
)
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
return
account
,
nil
}
}
...
...
@@ -5185,27 +5091,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return
normalized
,
nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func
(
s
*
GatewayService
)
checkAntigravityModelScope
(
ctx
context
.
Context
,
groupID
int64
,
requestedModel
string
)
error
{
scope
,
ok
:=
ResolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
nil
// 无法解析 scope,跳过检查
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
// 查询失败时放行
}
if
group
==
nil
{
return
nil
// 分组不存在时放行
}
if
!
IsScopeSupported
(
group
.
SupportedModelScopes
,
scope
)
{
return
ErrModelScopeNotSupported
}
return
nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gateway_service_benchmark_test.go
View file @
d367d1cd
...
...
@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b
.
ReportAllocs
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
parsed
,
err
:=
ParseGatewayRequest
(
body
)
parsed
,
err
:=
ParseGatewayRequest
(
body
,
""
)
if
err
!=
nil
{
b
.
Fatalf
(
"解析请求失败: %v"
,
err
)
}
...
...
backend/internal/service/gemini_error_policy_test.go
0 → 100644
View file @
d367d1cd
//go:build unit
package
service
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func
TestShouldFailoverGeminiUpstreamError
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{}
tests
:=
[]
struct
{
name
string
statusCode
int
expected
bool
}{
{
"401_failover"
,
401
,
true
},
{
"403_failover"
,
403
,
true
},
{
"429_failover"
,
429
,
true
},
{
"529_failover"
,
529
,
true
},
{
"500_failover"
,
500
,
true
},
{
"502_failover"
,
502
,
true
},
{
"503_failover"
,
503
,
true
},
{
"400_no_failover"
,
400
,
false
},
{
"404_no_failover"
,
404
,
false
},
{
"422_no_failover"
,
422
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
shouldFailoverGeminiUpstreamError
(
tt
.
statusCode
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func
TestCheckErrorPolicy_GeminiAccounts
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
body
[]
byte
expected
ErrorPolicyResult
}{
{
name
:
"gemini_apikey_custom_codes_hit"
,
account
:
&
Account
{
ID
:
100
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
),
float64
(
500
)},
},
},
statusCode
:
429
,
body
:
[]
byte
(
`{"error":"rate limited"}`
),
expected
:
ErrorPolicyMatched
,
},
{
name
:
"gemini_apikey_custom_codes_miss"
,
account
:
&
Account
{
ID
:
101
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
500
,
body
:
[]
byte
(
`{"error":"internal"}`
),
expected
:
ErrorPolicySkipped
,
},
{
name
:
"gemini_apikey_no_custom_codes_returns_none"
,
account
:
&
Account
{
ID
:
102
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
500
,
body
:
[]
byte
(
`{"error":"internal"}`
),
expected
:
ErrorPolicyNone
,
},
{
name
:
"gemini_apikey_temp_unschedulable_hit"
,
account
:
&
Account
{
ID
:
103
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded service`
),
expected
:
ErrorPolicyTempUnscheduled
,
},
{
name
:
"gemini_custom_codes_override_temp_unschedulable"
,
account
:
&
Account
{
ID
:
104
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
503
)},
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded`
),
expected
:
ErrorPolicyMatched
,
// custom codes take precedence
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
errorPolicyRepoStub
{}
svc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
result
:=
svc
.
CheckErrorPolicy
(
context
.
Background
(),
tt
.
account
,
tt
.
statusCode
,
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func
TestGeminiErrorPolicyIntegration
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
respBody
[]
byte
expectFailover
bool
// expect UpstreamFailoverError
expectHandleError
bool
// expect handleGeminiUpstreamError to be called
expectShouldFailover
bool
// for None path, whether shouldFailover triggers
}{
{
name
:
"custom_codes_matched_429_failover"
,
account
:
&
Account
{
ID
:
200
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
429
,
respBody
:
[]
byte
(
`{"error":"rate limited"}`
),
expectFailover
:
true
,
expectHandleError
:
true
,
},
{
name
:
"custom_codes_skipped_500_no_failover"
,
account
:
&
Account
{
ID
:
201
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
500
,
respBody
:
[]
byte
(
`{"error":"internal"}`
),
expectFailover
:
false
,
expectHandleError
:
false
,
},
{
name
:
"temp_unschedulable_matched_failover"
,
account
:
&
Account
{
ID
:
202
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
respBody
:
[]
byte
(
`overloaded`
),
expectFailover
:
true
,
expectHandleError
:
true
,
},
{
name
:
"no_policy_429_failover_via_shouldFailover"
,
account
:
&
Account
{
ID
:
203
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
429
,
respBody
:
[]
byte
(
`{"error":"rate limited"}`
),
expectFailover
:
true
,
expectHandleError
:
true
,
expectShouldFailover
:
true
,
},
{
name
:
"no_policy_400_no_failover"
,
account
:
&
Account
{
ID
:
204
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
400
,
respBody
:
[]
byte
(
`{"error":"bad request"}`
),
expectFailover
:
false
,
expectHandleError
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
geminiErrorPolicyRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
rateLimitService
:
rlSvc
,
}
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var
handleErrorCalled
bool
var
gotFailover
bool
ctx
:=
context
.
Background
()
statusCode
:=
tt
.
statusCode
respBody
:=
tt
.
respBody
account
:=
tt
.
account
headers
:=
http
.
Header
{}
if
svc
.
rateLimitService
!=
nil
{
switch
svc
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
statusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover
=
false
handleErrorCalled
=
false
goto
verify
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
statusCode
,
headers
,
respBody
)
handleErrorCalled
=
true
gotFailover
=
true
goto
verify
}
}
// ErrorPolicyNone → original logic
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
statusCode
,
headers
,
respBody
)
handleErrorCalled
=
true
if
svc
.
shouldFailoverGeminiUpstreamError
(
statusCode
)
{
gotFailover
=
true
}
verify
:
require
.
Equal
(
t
,
tt
.
expectFailover
,
gotFailover
,
"failover mismatch"
)
require
.
Equal
(
t
,
tt
.
expectHandleError
,
handleErrorCalled
,
"handleGeminiUpstreamError call mismatch"
)
if
tt
.
expectShouldFailover
{
require
.
True
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
statusCode
),
"shouldFailoverGeminiUpstreamError should return true for status %d"
,
statusCode
)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func
TestGeminiErrorPolicy_NilRateLimitService
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{
rateLimitService
:
nil
,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx
:=
context
.
Background
()
account
:=
&
Account
{
ID
:
300
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if
svc
.
rateLimitService
!=
nil
{
t
.
Fatal
(
"rateLimitService should be nil for this test"
)
}
// shouldFailoverGeminiUpstreamError still works
require
.
True
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
429
))
require
.
False
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
400
))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require
.
NotPanics
(
t
,
func
()
{
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
500
,
http
.
Header
{},
[]
byte
(
`error`
))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type
geminiErrorPolicyRepo
struct
{
mockAccountRepoForGemini
setErrorCalls
int
setRateLimitedCalls
int
setTempCalls
int
}
func
(
r
*
geminiErrorPolicyRepo
)
SetError
(
_
context
.
Context
,
_
int64
,
_
string
)
error
{
r
.
setErrorCalls
++
return
nil
}
func
(
r
*
geminiErrorPolicyRepo
)
SetRateLimited
(
_
context
.
Context
,
_
int64
,
_
time
.
Time
)
error
{
r
.
setRateLimitedCalls
++
return
nil
}
func
(
r
*
geminiErrorPolicyRepo
)
SetTempUnschedulable
(
_
context
.
Context
,
_
int64
,
_
time
.
Time
,
_
string
)
error
{
r
.
setTempCalls
++
return
nil
}
backend/internal/service/gemini_messages_compat_service.go
View file @
d367d1cd
...
...
@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
baseURL
:=
account
.
GetGeminiBaseURL
(
geminicli
.
AIStudioBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
...
...
@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
baseURL
:=
account
.
GetGeminiBaseURL
(
geminicli
.
AIStudioBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
...
...
@@ -837,12 +831,17 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
tempMatched
:=
false
// 统一错误策略:自定义错误码 + 临时不可调度
if
s
.
rateLimitService
!=
nil
{
tempMatched
=
s
.
rateLimitService
.
HandleTempUnschedulable
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
switch
s
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
return
nil
,
s
.
writeGeminiMappedError
(
c
,
account
,
resp
.
StatusCode
,
upstreamReqID
,
respBody
)
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
tempMatched
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
...
...
@@ -869,6 +868,10 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
}
}
// ErrorPolicyNone → 原有逻辑
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
...
...
@@ -1026,10 +1029,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
baseURL
:=
account
.
GetGeminiBaseURL
(
geminicli
.
AIStudioBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
...
...
@@ -1097,10 +1097,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
baseURL
:=
account
.
GetGeminiBaseURL
(
geminicli
.
AIStudioBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
...
...
@@ -1261,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
tempMatched
:=
false
if
s
.
rateLimitService
!=
nil
{
tempMatched
=
s
.
rateLimitService
.
HandleTempUnschedulable
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
}
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if
action
==
"countTokens"
&&
isOAuth
&&
isGeminiInsufficientScope
(
resp
.
Header
,
respBody
)
{
estimated
:=
estimateGeminiCountTokens
(
body
)
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
...
...
@@ -1282,7 +1274,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
},
nil
}
if
tempMatched
{
// 统一错误策略:自定义错误码 + 临时不可调度
if
s
.
rateLimitService
!=
nil
{
switch
s
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
respBody
=
unwrapIfNeeded
(
isOAuth
,
respBody
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
}
c
.
Data
(
resp
.
StatusCode
,
contentType
,
respBody
)
return
nil
,
fmt
.
Errorf
(
"gemini upstream error: %d (skipped by error policy)"
,
resp
.
StatusCode
)
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
...
...
@@ -1306,6 +1310,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
}
}
// ErrorPolicyNone → 原有逻辑
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
...
...
@@ -2420,10 +2428,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return
nil
,
errors
.
New
(
"invalid path"
)
}
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
baseURL
:=
account
.
GetGeminiBaseURL
(
geminicli
.
AIStudioBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment