Commit d367d1cd authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'main' into test-sora

parents d7011163 3c46f7d2
package service
import (
"encoding/json"
"strings"
"time"
)
// Anthropic 会话 Fallback 相关常量
const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
func AnthropicSessionTTL() time.Duration {
return anthropicSessionTTLSeconds * time.Second
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var parts []string
// 1. system prompt
if parsed.System != nil {
systemData, _ := json.Marshal(parsed.System)
if len(systemData) > 0 && string(systemData) != "null" {
parts = append(parts, "s:"+shortHash(systemData))
}
}
// 2. messages
for _, msg := range parsed.Messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
prefix := rolePrefix(role)
content := msgMap["content"]
contentData, _ := json.Marshal(content)
parts = append(parts, prefix+":"+shortHash(contentData))
}
return strings.Join(parts, "-")
}
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
func rolePrefix(role string) string {
switch role {
case "assistant":
return "a"
default:
return "u"
}
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
}
package service
import (
"strings"
"testing"
)
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
result := BuildAnthropicDigestChain(nil)
if result != "" {
t.Errorf("expected empty string for nil request, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{},
}
result := BuildAnthropicDigestChain(parsed)
if result != "" {
t.Errorf("expected empty string for empty messages, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "a:") {
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
parsed := &ParsedRequest{
System: "You are a helpful assistant",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "u:") {
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant"},
},
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system := "You are a helpful assistant"
// 第 1 轮: system + user
round1 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(round1)
// 第 2 轮: system + user + assistant + user
round2 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
},
}
chain2 := BuildAnthropicDigestChain(round2)
// 第 3 轮: system + user + assistant + user + assistant + user
round3 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well"},
map[string]any{"role": "user", "content": "great"},
},
}
chain3 := BuildAnthropicDigestChain(round3)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
t.Logf("Chain3: %s", chain3)
// chain1 是 chain2 的前缀
if !strings.HasPrefix(chain2, chain1) {
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
}
// chain2 是 chain3 的前缀
if !strings.HasPrefix(chain3, chain2) {
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
}
// chain1 也是 chain3 的前缀(传递性)
if !strings.HasPrefix(chain3, chain1) {
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
}
}
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
System: "System A",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "System B",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different system prompts should produce different chains")
}
// 但 user 部分的 hash 应该相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[1] != parts2[1] {
t.Error("Same user message should produce same hash regardless of system")
}
}
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
map[string]any{"role": "user", "content": "next"},
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
map[string]any{"role": "user", "content": "next"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different content should produce different chains")
}
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
// 第一个 user message hash 应该相同
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
// assistant reply hash 应该不同
if parts1[1] == parts2[1] {
t.Error("Assistant reply hash should differ")
}
}
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
System: "test system",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
chain1 := BuildAnthropicDigestChain(parsed)
chain2 := BuildAnthropicDigestChain(parsed)
if chain1 != chain2 {
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "anthropic:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "anthropic:digest:12345678:abcdefgh",
},
{
name: "short values",
prefixHash: "abc",
uuid: "xyz",
want: "anthropic:digest:abc:xyz",
},
{
name: "empty values",
prefixHash: "",
uuid: "",
want: "anthropic:digest::",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestAnthropicSessionTTL(t *testing.T) {
ttl := AnthropicSessionTTL()
if ttl.Seconds() != 300 {
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
}
}
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
// 测试 content 为 content blocks 数组的情况
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe this image"},
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
},
},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
......@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
mathrand "math/rand"
"net"
"net/http"
......@@ -35,7 +36,7 @@ const (
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
antigravityRateLimitThreshold = 7 * time.Second
antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间
antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// Google RPC 状态和类型常量
......@@ -100,12 +101,11 @@ type antigravityRetryLoopParams struct {
accessToken string
action string
body []byte
quotaScope AntigravityQuotaScope
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
accountRepo AccountRepository // 用于智能重试的模型级别限流
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
requestedModel string // 用于限流检查的原始请求模型
isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断)
groupID int64 // 用于模型级限流时清除粘性会话
......@@ -148,13 +148,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel {
log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)",
p.prefix, resp.StatusCode, modelName, p.account.ID)
rateLimitDuration := waitDuration
if rateLimitDuration <= 0 {
rateLimitDuration = antigravityDefaultRateLimitDuration
}
log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200))
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
resetAt := time.Now().Add(rateLimitDuration)
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
} else {
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
......@@ -190,7 +194,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
......@@ -233,20 +237,33 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
// 所有重试都失败,限流当前模型并切换账号
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID)
rateLimitDuration := waitDuration
if rateLimitDuration <= 0 {
rateLimitDuration = antigravityDefaultRateLimitDuration
}
retryBody := lastRetryBody
if retryBody == nil {
retryBody = respBody
}
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
resetAt := time.Now().Add(rateLimitDuration)
if p.accountRepo != nil && modelName != "" {
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
} else {
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration)
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration)
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
}
// 清除粘性会话绑定,避免下次请求仍命中限流账号
if s.cache != nil && p.sessionHash != "" {
_ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
// 返回账号切换信号,让上层切换账号重试
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
......@@ -264,22 +281,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,根据剩余时间决定等待或切换
// 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
if remaining < antigravityRateLimitThreshold {
// 限流剩余时间较短,等待后继续
log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
select {
case <-p.ctx.Done():
return nil, p.ctx.Err()
case <-time.After(remaining):
}
} else {
// 限流剩余时间较长,返回账号切换信号
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID)
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
......@@ -287,7 +293,6 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
}
}
}
}
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
......@@ -360,11 +365,26 @@ urlFallbackLoop:
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
// 统一处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// ★ 统一入口:自定义错误码 + 临时不可调度
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {
return nil, policyErr
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
// 尝试智能重试处理(OAuth 账号专用)
smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs)
switch smartResult.action {
......@@ -406,7 +426,7 @@ urlFallbackLoop:
}
// 重试用尽,标记账户限流
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
resp = &http.Response{
StatusCode: resp.StatusCode,
......@@ -416,11 +436,8 @@ urlFallbackLoop:
break urlFallbackLoop
}
// 其他可重试错误(不包括 429 和 503,因为上面已处理)
if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// 其他可重试错误(500/502/504/529,不包括 429 和 503)
if shouldRetryAntigravityError(resp.StatusCode) {
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
......@@ -441,6 +458,9 @@ urlFallbackLoop:
}
continue
}
}
// 其他 4xx 错误或重试用尽,直接返回
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
......@@ -449,6 +469,7 @@ urlFallbackLoop:
break urlFallbackLoop
}
// 成功响应(< 400)
break urlFallbackLoop
}
}
......@@ -581,6 +602,31 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
return truncateString(string(body), maxBytes)
}
// checkErrorPolicy nil 安全的包装
func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult {
if s.rateLimitService == nil {
return ErrorPolicyNone
}
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
}
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
case ErrorPolicySkipped:
return true, nil
case ErrorPolicyMatched:
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return true, nil
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
}
return false, nil
}
// mapAntigravityModel 获取映射后的模型名
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping)
// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号
......@@ -650,6 +696,7 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 获取 token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
......@@ -964,8 +1011,24 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
}
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
// 上游透传账号直接转发,不走 OAuth token 刷新
if account.Type == AccountTypeUpstream {
return s.ForwardUpstream(ctx, c, account, body)
}
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
......@@ -983,11 +1046,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if mappedModel == "" {
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
}
loadModel := mappedModel
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token
if s.tokenProvider == nil {
......@@ -1022,11 +1083,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel)
}
// 执行带重试的请求
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
......@@ -1036,7 +1092,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
accessToken: accessToken,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
......@@ -1117,7 +1172,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
......@@ -1228,7 +1282,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
......@@ -1258,6 +1312,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
......@@ -1267,6 +1322,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
......@@ -1285,6 +1341,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
......@@ -1582,211 +1639,20 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return changed, nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude, 0, "", false)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
usage := &ClaudeUsage{}
var firstTokenMs *int
var firstTokenRecorded bool
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
// 记录首 token 时间
if !firstTokenRecorded && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
firstTokenRecorded = true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if bytes.HasPrefix(line, []byte("data: ")) {
dataStr := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]any
if json.Unmarshal(dataStr, &event) == nil {
if u, ok := event["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
}
}
// 透传行
_, _ = c.Writer.Write(line)
_, _ = c.Writer.Write([]byte("\n"))
c.Writer.Flush()
}
return usage, firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
// ForwardGemini 转发 Gemini 协议请求
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
......@@ -1799,7 +1665,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 解析请求以获取 image_size(用于图片计费)
imageSize := s.extractImageSize(body)
......@@ -1869,11 +1734,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
}
// 执行带重试的请求
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
......@@ -1883,7 +1743,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
accessToken: accessToken,
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
......@@ -1957,7 +1816,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps)
......@@ -2004,6 +1863,7 @@ handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if stream {
// 客户端要求流式,直接透传
......@@ -2014,6 +1874,7 @@ handleSuccess:
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
......@@ -2043,6 +1904,7 @@ handleSuccess:
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
......@@ -2253,9 +2115,9 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
}
// retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, 0, info.ModelName
return false, true, info.RetryDelay, info.ModelName
}
// retryDelay < 阈值:智能重试
......@@ -2377,10 +2239,10 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte
func (s *AntigravityGatewayService) handleUpstreamError(
ctx context.Context, prefix string, account *Account,
statusCode int, headers http.Header, body []byte,
quotaScope AntigravityQuotaScope,
requestedModel string,
groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult {
// 模型级限流处理(在原有逻辑之前
// 模型级限流处理(优先
result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx,
prefix: prefix,
......@@ -2402,53 +2264,36 @@ func (s *AntigravityGatewayService) handleUpstreamError(
return nil
}
// ========== 原有逻辑,保持不变 ==========
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
// 429:尝试解析模型级限流,解析失败时兜底为账号级限流
if statusCode == 429 {
// 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。
if logBody, maxBytes := s.getLogConfig(); logBody {
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
}
useScopeLimit := quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败:使用默认限流时间(与临时限流保持一致)
// 可通过配置或环境变量覆盖
defaultDur := antigravityDefaultRateLimitDuration
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
}
// 秒级环境变量优先级最高
if override, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = override
}
ra := time.Now().Add(defaultDur)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
defaultDur := s.getDefaultRateLimitDuration()
// 尝试解析模型 key 并设置模型级限流
modelKey := resolveAntigravityModelKey(requestedModel)
if modelKey != "" {
ra := s.resolveResetTime(resetAt, defaultDur)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
} else {
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra)
}
return nil
}
resetTime := time.Unix(*resetAt, 0)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
// 无法解析模型 key,兜底为账号级限流
ra := s.resolveResetTime(resetAt, defaultDur)
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return nil
}
// 其他错误码继续使用 rateLimitService
......@@ -2462,11 +2307,92 @@ func (s *AntigravityGatewayService) handleUpstreamError(
return nil
}
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
// getDefaultRateLimitDuration 获取默认限流时间
func (s *AntigravityGatewayService) getDefaultRateLimitDuration() time.Duration {
defaultDur := antigravityDefaultRateLimitDuration
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
}
if override, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = override
}
return defaultDur
}
// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点
func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur time.Duration) time.Time {
if resetAt != nil {
return time.Unix(*resetAt, 0)
}
return time.Now().Add(defaultDur)
}
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。
type antigravityClientWriter struct {
w gin.ResponseWriter
flusher http.Flusher
disconnected bool
prefix string // 日志前缀,标识来源方法
}
func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter {
return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix}
}
// Write 写入数据到客户端,写入失败时标记断开并返回 false
func (cw *antigravityClientWriter) Write(p []byte) bool {
if cw.disconnected {
return false
}
if _, err := cw.w.Write(p); err != nil {
cw.markDisconnected()
return false
}
cw.flusher.Flush()
return true
}
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool {
if cw.disconnected {
return false
}
if _, err := fmt.Fprintf(cw.w, format, args...); err != nil {
cw.markDisconnected()
return false
}
cw.flusher.Flush()
return true
}
func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected }
func (cw *antigravityClientWriter) markDisconnected() {
cw.disconnected = true
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
}
// handleStreamReadError 处理上游读取错误的通用逻辑。
// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
return true, true
}
if clientDisconnected {
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
return true, true
}
return false, false
}
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
......@@ -2542,10 +2468,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
......@@ -2557,9 +2485,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
......@@ -2574,11 +2505,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
cw.Fprintf("%s\n", line)
continue
}
......@@ -2614,27 +2541,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs = &ms
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
cw.Fprintf("data: %s\n\n", payload)
continue
}
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
cw.Fprintf("%s\n", line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
......@@ -3338,10 +3260,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
......@@ -3349,19 +3273,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
flusher.Flush()
}
// finishUsage 是获取 processor 最终 usage 的辅助函数
finishUsage := func() *ClaudeUsage {
_, agUsage := processor.Finish()
return convertUsage(agUsage)
}
for {
select {
case ev, ok := <-events:
if !ok {
// 发送结束事件
// 上游完成,发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
cw.Write(finalEvents)
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
......@@ -3371,25 +3303,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
if len(claudeEvents) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
cw.Write(claudeEvents)
}
case <-intervalCh:
......@@ -3397,13 +3318,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
......@@ -3542,3 +3465,288 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
payload["contents"] = filtered
return json.Marshal(payload)
}
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", false)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
streamRes := s.streamUpstreamResponse(c, resp, startTime)
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
usage := &ClaudeUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
for {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
}
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
line := ev.line
// 记录首 token 时间
if firstTokenMs == nil && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
s.extractSSEUsage(line, usage)
// 透传行
cw.Fprintf("%s\n", line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
}
log.Printf("Stream data interval timeout (antigravity upstream)")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
}
}
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
if !strings.HasPrefix(line, "data: ") {
return
}
dataStr := strings.TrimPrefix(line, "data: ")
var event map[string]any
if json.Unmarshal([]byte(dataStr), &event) != nil {
return
}
u, ok := event["usage"].(map[string]any)
if !ok {
return
}
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
......@@ -4,18 +4,42 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
type antigravityFailingWriter struct {
gin.ResponseWriter
failAfter int // 允许成功写入的次数,之后所有写入返回错误
writes int
}
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
return &AntigravityGatewayService{
settingService: &SettingService{cfg: cfg},
}
}
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
......@@ -338,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
......@@ -393,10 +417,16 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
......@@ -404,25 +434,458 @@ func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n"))
_, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n"))
fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`)
fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`)
}()
svc := &AntigravityGatewayService{}
start := time.Now().Add(-10 * time.Millisecond)
usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start)
result := svc.streamUpstreamResponse(c, resp, start)
_ = pr.Close()
require.NotNil(t, usage)
require.Equal(t, 1, usage.InputTokens)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 1, result.usage.InputTokens)
// 第二次事件覆盖 output_tokens
require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 3, usage.CacheReadInputTokens)
require.Equal(t, 4, usage.CacheCreationInputTokens)
require.Equal(t, 5, result.usage.OutputTokens)
require.Equal(t, 3, result.usage.CacheReadInputTokens)
require.Equal(t, 4, result.usage.CacheCreationInputTokens)
require.NotNil(t, result.firstTokenMs)
if firstTokenMs == nil {
t.Fatalf("expected firstTokenMs to be set")
}
// 确保有透传输出
require.True(t, strings.Contains(writer.Body.String(), "data:"))
require.Contains(t, rec.Body.String(), "data:")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: content_block_delta`)
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, "content_block_delta")
require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 第一个 chunk(部分内容)
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
fmt.Fprintln(pw, "")
// 第二个 chunk(最终内容+完整 usage)
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require.Equal(t, 8, result.usage.InputTokens)
require.Equal(t, 8, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.CacheReadInputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "Hello")
require.Contains(t, body, "world")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require.Equal(t, 5, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证输出是 Claude SSE 格式(processor 会转换)
body := rec.Body.String()
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotNil(t, result.usage)
require.Equal(t, 20, result.usage.OutputTokens)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证:context 取消时返回 usage 且标记 clientDisconnect
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
fmt.Fprintln(pw, "")
// 不关闭 pw → 等待超时
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "write_failed")
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func TestExtractSSEUsage(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
line string
expected ClaudeUsage
}{
{
name: "message_delta with output_tokens",
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
expected: ClaudeUsage{OutputTokens: 42},
},
{
name: "non-data line ignored",
line: `event: message_start`,
expected: ClaudeUsage{},
},
{
name: "top-level usage with all fields",
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &ClaudeUsage{}
svc.extractSSEUsage(tt.line, usage)
require.Equal(t, tt.expected, *usage)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
ok := cw.Write([]byte("hello"))
require.True(t, ok)
require.False(t, cw.Disconnected())
require.Contains(t, rec.Body.String(), "hello")
})
t.Run("write failure marks disconnected", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
ok := cw.Write([]byte("hello"))
require.False(t, ok)
require.True(t, cw.Disconnected())
})
t.Run("subsequent writes are no-op", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
cw.Write([]byte("first"))
ok := cw.Fprintf("second %d", 2)
require.False(t, ok)
require.True(t, cw.Disconnected())
})
}
......@@ -2,63 +2,23 @@ package service
import (
"context"
"slices"
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func resolveAntigravityModelKey(requestedModel string) string {
return normalizeAntigravityModelName(requestedModel)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
......@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
var antigravityAllScopes = []AntigravityQuotaScope{
AntigravityQuotaScopeClaude,
AntigravityQuotaScopeGeminiText,
AntigravityQuotaScopeGeminiImage,
}
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
if a == nil || a.Platform != PlatformAntigravity {
return nil
}
now := time.Now()
result := make(map[string]int64)
for _, scope := range antigravityAllScopes {
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt != nil && now.Before(*resetAt) {
remainingSec := int64(time.Until(*resetAt).Seconds())
if remainingSec > 0 {
result[string(scope)] = remainingSec
}
}
}
if len(result) == 0 {
return nil
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
}
......@@ -65,12 +65,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
......@@ -84,16 +78,10 @@ type modelRateLimitCall struct {
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
......@@ -137,10 +125,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
......@@ -161,23 +148,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
......@@ -195,7 +165,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
......@@ -206,22 +176,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 不应该触发模型限流,应该走 scope 限流
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED),
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Len(t, repo.scopeCalls, 1)
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
......@@ -241,7 +211,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
......@@ -269,12 +239,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 非模型限流不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
......@@ -287,12 +256,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
// 503 + 空响应体 → 不做任何处理
body := []byte(`{}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls)
require.Empty(t, repo.rateCalls)
}
......@@ -313,15 +281,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
......@@ -641,6 +601,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 7 * time.Second,
modelName: "gemini-pro",
},
{
......@@ -658,6 +619,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 39 * time.Second,
modelName: "gemini-3-pro-high",
},
{
......@@ -675,6 +637,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "gemini-2.5-flash",
},
{
......@@ -692,6 +655,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "claude-sonnet-4-5",
},
}
......@@ -710,6 +674,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
}
}
if shouldRateLimit && tt.minWait > 0 {
if wait < tt.minWait {
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
}
}
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
t.Errorf("modelName = %q, want %q", model, tt.modelName)
}
......@@ -809,7 +778,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
}
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 1,
......@@ -821,19 +790,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
......@@ -842,17 +807,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, result)
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
}
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 2,
......@@ -881,7 +850,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
......
......@@ -13,6 +13,23 @@ import (
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
// 仅关注 DeleteSessionAccountID 的调用记录
type stubSmartRetryCache struct {
GatewayCache // 嵌入接口,未实现的方法 panic(确保只调用预期方法)
deleteCalls []deleteSessionCall
}
type deleteSessionCall struct {
groupID int64
sessionHash string
}
func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error {
c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash})
return nil
}
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
......@@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
// 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次)
// 智能重试后仍然返回 429(需要提供 1 个响应,因为智能重试最多 1 次)
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
......@@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp2 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp3 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp1, failResp2, failResp3},
errors: []error{nil, nil, nil},
responses: []*http.Response{failResp1},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
......@@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Platform: PlatformAntigravity,
}
// 3s < 7s 阈值,应该触发智能重试(最多 3 次)
// 3s < 7s 阈值,应该触发智能重试(最多 1 次)
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
......@@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -284,7 +291,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey)
require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
......@@ -324,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -380,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -429,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -480,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -541,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
......@@ -556,19 +563,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
// 第一次网络错误,第二次成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号
func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
// 唯一一次重试遇到网络错误(nil response)
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误)
errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发
responses: []*http.Response{nil}, // 返回 nil(模拟网络错误)
errors: []error{nil}, // mock 不返回 error,靠 nil response 触发
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 8,
Name: "acc-8",
......@@ -600,7 +603,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -612,10 +616,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
require.Len(t, upstream.calls, 2, "should have made two retry calls")
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.Len(t, upstream.calls, 1, "should have made one retry call")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
......@@ -653,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
......@@ -674,3 +683,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// ---------------------------------------------------------------------------
// 以下测试覆盖本次改动:
// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次)
// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID)
// ---------------------------------------------------------------------------
// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1
func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) {
require.Equal(t, 1, antigravitySmartRetryMaxAttempts,
"antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 10,
Name: "acc-10",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-abc",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
// 验证返回 switchError
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
// 核心断言:DeleteSessionAccountID 被调用,且参数正确
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once")
require.Equal(t, int64(42), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash)
// 验证仅重试 1 次
require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空)
func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 11,
Name: "acc-11",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
groupID: 42,
sessionHash: "", // 非粘性会话,sessionHash 为空
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.False(t, result.switchError.IsStickySession)
// 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
// 边界:cache 为 nil 时不应 panic
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 12,
Name: "acc-12",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-nil-cache",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
// cache 为 nil,不应 panic
svc := &AntigravityGatewayService{cache: nil}
require.NotPanics(t, func() {
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
})
}
// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
// 重试成功时不应清除粘性会话(只有失败才清除)
func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 13,
Name: "acc-13",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-success",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
// 核心断言:重试成功时不应清除粘性会话
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry")
}
// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID
// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理)
func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 14,
Name: "acc-14",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 15s >= 7s 阈值 → 走长延迟路径
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-long-delay",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID
// (由上游 handler 的 shouldClearStickySession 处理)
require.Len(t, cache.deleteCalls, 0,
"long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)")
}
// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定
func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil}, // 网络错误
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 15,
Name: "acc-15",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 99,
sessionHash: "sticky-net-error",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 核心断言:网络错误耗尽重试后也应清除粘性绑定
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry")
require.Equal(t, int64(99), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash)
}
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 16,
Name: "acc-16",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 77,
sessionHash: "sticky-503-short",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1)
require.Equal(t, int64(77), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey)
}
// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播
// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除
func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) {
// 初始 429 响应
initialRespBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
initialResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initialRespBody)),
}
// 智能重试也返回 429
retryRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
retryResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(retryRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{initialResp, retryResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 17,
Name: "acc-17",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{cache: cache}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 55,
sessionHash: "sticky-loop-test",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result when switchError")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop")
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
}
package service
import (
"testing"
)
func TestBuildSelectedSet(t *testing.T) {
tests := []struct {
name string
ids []string
wantNil bool
wantSize int
}{
{
name: "nil input returns nil (backward compatible: create all)",
ids: nil,
wantNil: true,
},
{
name: "empty slice returns empty map (create none)",
ids: []string{},
wantNil: false,
wantSize: 0,
},
{
name: "single ID",
ids: []string{"abc-123"},
wantNil: false,
wantSize: 1,
},
{
name: "multiple IDs",
ids: []string{"a", "b", "c"},
wantNil: false,
wantSize: 3,
},
{
name: "duplicate IDs are deduplicated",
ids: []string{"a", "a", "b"},
wantNil: false,
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildSelectedSet(tt.ids)
if tt.wantNil {
if got != nil {
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
}
return
}
if got == nil {
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
}
if len(got) != tt.wantSize {
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
}
// Verify all unique IDs are present
for _, id := range tt.ids {
if _, ok := got[id]; !ok {
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
}
}
})
}
}
func TestShouldCreateAccount(t *testing.T) {
tests := []struct {
name string
crsID string
selectedSet map[string]struct{}
want bool
}{
{
name: "nil set allows all (backward compatible)",
crsID: "any-id",
selectedSet: nil,
want: true,
},
{
name: "empty set blocks all",
crsID: "any-id",
selectedSet: map[string]struct{}{},
want: false,
},
{
name: "ID in set is allowed",
crsID: "abc-123",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: true,
},
{
name: "ID not in set is blocked",
crsID: "xyz-789",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
if got != tt.want {
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
tt.crsID, tt.selectedSet, got, tt.want)
}
})
}
}
......@@ -49,6 +49,7 @@ type SyncFromCRSInput struct {
Username string
Password string
SyncProxies bool
SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs
}
type SyncFromCRSItemResult struct {
......@@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct {
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
// fetchCRSExport validates the connection parameters, authenticates with CRS,
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) {
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL := strings.TrimSpace(input.BaseURL)
normalizedURL := strings.TrimSpace(baseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
normalizedURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
normalizedURL = normalized
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
return nil, errors.New("username and password are required")
}
......@@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
if err != nil {
return nil, err
}
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
return crsExportAccounts(ctx, client, normalizedURL, adminToken)
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
......@@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
selectedSet := buildSelectedSet(input.SelectedAccountIDs)
var proxies []Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
......@@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
......@@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
......@@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
......@@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
......@@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
......@@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
......@@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
return newCredentials
}
// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup.
// Returns nil if ids is nil (field not sent → backward compatible: create all).
// Returns an empty map if ids is non-nil but empty (user selected none → create none).
func buildSelectedSet(ids []string) map[string]struct{} {
if ids == nil {
return nil
}
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
// shouldCreateAccount checks if a new CRS account should be created based on user selection.
// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set.
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
if selectedSet == nil {
return true
}
_, ok := selectedSet[crsID]
return ok
}
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
type PreviewFromCRSResult struct {
NewAccounts []CRSPreviewAccount `json:"new_accounts"`
ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
}
// CRSPreviewAccount represents a single account in the preview result.
type CRSPreviewAccount struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
}
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
// as new or existing by batch-querying local crs_account_id mappings.
func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
// Batch query all existing CRS account IDs
existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err)
}
result := &PreviewFromCRSResult{
NewAccounts: make([]CRSPreviewAccount, 0),
ExistingAccounts: make([]CRSPreviewAccount, 0),
}
classify := func(crsID, kind, name, platform, accountType string) {
preview := CRSPreviewAccount{
CRSAccountID: crsID,
Kind: kind,
Name: defaultName(name, crsID),
Platform: platform,
Type: accountType,
}
if _, exists := existingCRSIDs[crsID]; exists {
result.ExistingAccounts = append(result.ExistingAccounts, preview)
} else {
result.NewAccounts = append(result.NewAccounts, preview)
}
}
for _, src := range exported.Data.ClaudeAccounts {
authType := strings.TrimSpace(src.AuthType)
if authType == "" {
authType = AccountTypeOAuth
}
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType)
}
for _, src := range exported.Data.ClaudeConsoleAccounts {
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey)
}
for _, src := range exported.Data.OpenAIOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth)
}
for _, src := range exported.Data.OpenAIResponsesAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey)
}
for _, src := range exported.Data.GeminiOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth)
}
for _, src := range exported.Data.GeminiAPIKeyAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey)
}
return result, nil
}
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream returns a fixed response for every request.
type epFixedUpstream struct {
statusCode int
body string
calls int
}
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.calls++
return &http.Response{
StatusCode: u.statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(u.body)),
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type epAccountRepo struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
}
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrCalls++
return nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func saveAndSetBaseURLs(t *testing.T) {
t.Helper()
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvail := antigravity.DefaultURLAvailability
antigravity.BaseURLs = []string{"https://ep-test.example"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
t.Cleanup(func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvail
})
}
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
return antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[ep-test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: handleError,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
tests := []struct {
name string
upstreamStatus int
upstreamBody string
customCodes []any
expectHandleError int
expectUpstream int
expectStatusCode int
}{
{
name: "429_in_custom_codes_matched",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(429)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "429_not_in_custom_codes_skipped",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "500_in_custom_codes_matched",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(500)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_not_in_custom_codes_skipped",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(429)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": tt.customCodes,
},
}
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
tempRulesAccount := func(rules []any) *Account {
return &Account{
ID: 200,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": rules,
},
}
}
overloadedRule := map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
}
rateLimitRule := map[string]any{
"error_code": float64(429),
"keywords": []any{"rate limited keyword"},
"duration_minutes": float64(5),
}
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{rateLimitRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
result, err := svc.antigravityRetryLoop(p)
// Context cancellation during backoff proves default retry was entered
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
// rateLimitService is nil — must not panic
svc := &AntigravityGatewayService{rateLimitService: nil}
account := &Account{
ID: 300,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
// Should not panic; enters the default retry path (eventually times out)
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
// Plain OAuth account with no error policy configured
account := &Account{
ID: 400,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "no_policy_oauth_returns_none",
account: &Account{
ID: 1,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
// no custom error codes, no temp rules
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_hit_returns_matched",
account: &Account{
ID: 2,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyMatched,
},
{
name: "custom_error_codes_miss_returns_skipped",
account: &Account{
ID: 3,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 503,
body: []byte(`"error"`),
expected: ErrorPolicySkipped,
},
{
name: "temp_unschedulable_hit_returns_temp_unscheduled",
account: &Account{
ID: 4,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{
ID: 5,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`random msg`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_override_temp_unschedulable",
account: &Account{
ID: 6,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func TestApplyErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expectedHandled bool
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
{
name: "none_not_handled",
account: &Account{
ID: 10,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
handleErrorCalls: 0,
},
{
name: "skipped_handled_no_handleError",
account: &Account{
ID: 11,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 0,
},
{
name: "matched_handled_calls_handleError",
account: &Account{
ID: 12,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 1,
},
{
name: "temp_unscheduled_returns_switch_error",
account: &Account{
ID: 13,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedSwitchErr: true,
handleErrorCalls: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{
rateLimitService: rlSvc,
}
var handleErrorCount int
p := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: tt.account,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
},
isStickySession: true,
}
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, retErr, &switchErr)
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
} else {
require.NoError(t, retErr)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type errorPolicyRepoStub struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
lastErrorMsg string
}
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrCalls++
r.lastErrorMsg = errorMsg
return nil
}
......@@ -77,7 +77,12 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) {
func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
......@@ -145,9 +150,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
......@@ -219,22 +221,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
......@@ -293,6 +279,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g
return nil, nil
}
func (m *mockGroupRepoForGateway) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
func ptr[T any](v T) *T {
return &v
}
......
......@@ -6,9 +6,19 @@ import (
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
......@@ -31,11 +41,13 @@ type ParsedRequest struct {
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
......@@ -64,6 +76,20 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID
}
}
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
}
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
......@@ -73,6 +99,7 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
......
......@@ -4,12 +4,13 @@ import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
......@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
......@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens)
}
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
......@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
// ============ Gemini 原生格式解析测试 ============
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
body := []byte(`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
require.Nil(t, parsed.System, "no systemInstruction means nil System")
}
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
body := []byte(`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
parts, ok := parsed.System.([]any)
require.True(t, ok)
require.Len(t, parts, 1)
partMap, ok := parts[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "You are a helpful assistant.", partMap["text"])
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
body := []byte(`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Equal(t, "gemini-2.5-pro", parsed.Model)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
// Gemini 格式下 system/messages 字段应被忽略
body := []byte(`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
require.Nil(t, parsed.System, "no systemInstruction = nil System")
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
}
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
body := []byte(`{"contents": []}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Empty(t, parsed.Messages)
}
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
body := []byte(`{"model": "gemini-2.5-flash"}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Nil(t, parsed.Messages)
require.Equal(t, "gemini-2.5-flash", parsed.Model)
}
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body := []byte(`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
require.NoError(t, err)
require.True(t, parsed.HasSystem)
require.Equal(t, "real system", parsed.System)
require.Len(t, parsed.Messages, 1)
msg, ok := parsed.Messages[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "real content", msg["content"])
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any
......
......@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
......@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
......@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
......@@ -245,9 +246,6 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
......@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
......@@ -295,24 +286,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
......@@ -323,21 +296,15 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 或请求的模型处于限流状态时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// within temporary unschedulable period, or the requested model is rate-limited.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
......@@ -349,8 +316,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
// 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
return false
......@@ -417,6 +384,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
......@@ -450,6 +418,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
......@@ -459,6 +428,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
......@@ -492,23 +462,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
_, _ = combined.WriteString(systemText)
}
}
for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok {
if content, exists := m["content"]; exists {
// Anthropic: messages[].content
if msgText := s.extractTextFromContent(content); msgText != "" {
_, _ = combined.WriteString(msgText)
}
} else if parts, ok := m["parts"].([]any); ok {
// Gemini: contents[].parts[].text
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
_, _ = combined.WriteString(text)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
}
}
}
}
if combined.Len() > 0 {
return s.hashContent(combined.String())
}
return ""
}
......@@ -536,19 +528,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息(uuid, accountID)
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
......@@ -633,8 +643,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
h := xxhash.Sum64String(content)
return strconv.FormatUint(h, 36)
}
// replaceModelInBody 替换请求体中的model字段
......@@ -993,13 +1003,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
......@@ -1114,7 +1117,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
......@@ -1194,6 +1196,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(routingAvailable)
// 4. 尝试获取槽位
for _, item := range routingAvailable {
......@@ -1268,7 +1271,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
......@@ -1348,10 +1350,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
......@@ -1366,71 +1364,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
// 分层过滤选择:优先级 → 负载率 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
......@@ -1470,7 +1404,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
available = newAvailable
}
}
}
// ============ Layer 3: 兜底排队 ============
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
......@@ -2004,87 +1937,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
shuffleWithinPriorityAndLastUsed(accounts)
}
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
func shuffleWithinSortGroups(accounts []accountWithLoad) {
if len(accounts) <= 1 {
return
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
j++
}
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
i = j
}
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
if a.account.Priority != b.account.Priority {
return false
}
return info.CallCount
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return false
}
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
if len(accounts) <= 1 {
return
}
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
j++
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
i = j
}
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
func sameAccountGroup(a, b *Account) bool {
if a.Priority != b.Priority {
return false
}
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
func sameLastUsedAt(a, b *time.Time) bool {
switch {
case a == nil && b == nil:
return true
case a == nil || b == nil:
return false
default:
return a.Unix() == b.Unix()
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
}
// sortCandidatesForFallback 根据配置选择排序策略
......@@ -2139,13 +2064,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
......@@ -2173,9 +2091,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
......@@ -2276,9 +2191,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
......@@ -2387,9 +2299,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
......@@ -2492,9 +2401,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
......@@ -5185,27 +5091,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope,跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
......@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}
......
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
statusCode int
expected bool
}{
{"401_failover", 401, true},
{"403_failover", 403, true},
{"429_failover", 429, true},
{"529_failover", 529, true},
{"500_failover", 500, true},
{"502_failover", 502, true},
{"503_failover", 503, true},
{"400_no_failover", 400, false},
{"404_no_failover", 404, false},
{"422_no_failover", 422, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "gemini_apikey_custom_codes_hit",
account: &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 429,
body: []byte(`{"error":"rate limited"}`),
expected: ErrorPolicyMatched,
},
{
name: "gemini_apikey_custom_codes_miss",
account: &Account{
ID: 101,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicySkipped,
},
{
name: "gemini_apikey_no_custom_codes_returns_none",
account: &Account{
ID: 102,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicyNone,
},
{
name: "gemini_apikey_temp_unschedulable_hit",
account: &Account{
ID: 103,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{
ID: 104,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicyIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
account *Account
statusCode int
respBody []byte
expectFailover bool // expect UpstreamFailoverError
expectHandleError bool // expect handleGeminiUpstreamError to be called
expectShouldFailover bool // for None path, whether shouldFailover triggers
}{
{
name: "custom_codes_matched_429_failover",
account: &Account{
ID: 200,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
},
{
name: "custom_codes_skipped_500_no_failover",
account: &Account{
ID: 201,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
respBody: []byte(`{"error":"internal"}`),
expectFailover: false,
expectHandleError: false,
},
{
name: "temp_unschedulable_matched_failover",
account: &Account{
ID: 202,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
respBody: []byte(`overloaded`),
expectFailover: true,
expectHandleError: true,
},
{
name: "no_policy_429_failover_via_shouldFailover",
account: &Account{
ID: 203,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
expectShouldFailover: true,
},
{
name: "no_policy_400_no_failover",
account: &Account{
ID: 204,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 400,
respBody: []byte(`{"error":"bad request"}`),
expectFailover: false,
expectHandleError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &geminiErrorPolicyRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &GeminiMessagesCompatService{
accountRepo: repo,
rateLimitService: rlSvc,
}
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var handleErrorCalled bool
var gotFailover bool
ctx := context.Background()
statusCode := tt.statusCode
respBody := tt.respBody
account := tt.account
headers := http.Header{}
if svc.rateLimitService != nil {
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
case ErrorPolicySkipped:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover = false
handleErrorCalled = false
goto verify
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
gotFailover = true
goto verify
}
}
// ErrorPolicyNone → original logic
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
gotFailover = true
}
verify:
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
if tt.expectShouldFailover {
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
svc := &GeminiMessagesCompatService{
rateLimitService: nil,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx := context.Background()
account := &Account{
ID: 300,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if svc.rateLimitService != nil {
t.Fatal("rateLimitService should be nil for this test")
}
// shouldFailoverGeminiUpstreamError still works
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require.NotPanics(t, func() {
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setRateLimitedCalls int
setTempCalls int
}
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrorCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
r.setRateLimitedCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.setTempCalls++
return nil
}
......@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
......@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
......@@ -837,12 +831,17 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
......@@ -869,6 +868,10 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
......@@ -1026,10 +1029,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
......@@ -1097,10 +1097,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
......@@ -1261,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
......@@ -1282,7 +1274,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
......@@ -1306,6 +1310,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
......@@ -2420,10 +2428,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment