Commit abf5de69 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'main' into test

parents 7582dc53 174d7c77
...@@ -155,6 +155,7 @@ type GeminiUsageMetadata struct { ...@@ -155,6 +155,7 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
} }
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) // GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
......
...@@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000 ...@@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000
// Gemini 2.5 Flash thinking budget 上限 // Gemini 2.5 Flash thinking budget 上限
const Gemini25FlashThinkingBudgetLimit = 24576 const Gemini25FlashThinkingBudgetLimit = 24576
// 对于 Antigravity 的 Claude(budget-only)模型,该语义最终等价为 thinkingBudget=24576。
// 这里复用相同数值以保持行为一致。
const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit
// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens // ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens
// Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens // Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens
// 返回调整后的 maxTokens 和是否进行了调整 // 返回调整后的 maxTokens 和是否进行了调整
...@@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map ...@@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
} }
// 检测是否启用 thinking // 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
...@@ -198,8 +202,7 @@ type modelInfo struct { ...@@ -198,8 +202,7 @@ type modelInfo struct {
// modelInfoMap 模型前缀 → 模型信息映射 // modelInfoMap 模型前缀 → 模型信息映射
// 只有在此映射表中的模型才会注入身份提示词 // 只有在此映射表中的模型才会注入身份提示词
// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking, // 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换
var modelInfoMap = map[string]modelInfo{ var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
...@@ -593,6 +596,10 @@ func maxOutputTokensLimit(model string) int { ...@@ -593,6 +596,10 @@ func maxOutputTokensLimit(model string) int {
return maxOutputTokensUpperBound return maxOutputTokensUpperBound
} }
func isAntigravityOpus46Model(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
maxLimit := maxOutputTokensLimit(req.Model) maxLimit := maxOutputTokensLimit(req.Model)
config := &GeminiGenerationConfig{ config := &GeminiGenerationConfig{
...@@ -606,25 +613,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { ...@@ -606,25 +613,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
} }
// Thinking 配置 // Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" { if req.Thinking != nil && (req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") {
config.ThinkingConfig = &GeminiThinkingConfig{ config.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true, IncludeThoughts: true,
} }
// - thinking.type=enabled:budget_tokens>0 用显式预算
// - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576)
budget := -1
if req.Thinking.BudgetTokens > 0 { if req.Thinking.BudgetTokens > 0 {
budget := req.Thinking.BudgetTokens budget = req.Thinking.BudgetTokens
}
if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
budget = ClaudeAdaptiveHighThinkingBudgetTokens
}
// 正预算需要做上限与 max_tokens 约束;动态预算(-1)直接透传给上游。
if budget > 0 {
// gemini-2.5-flash 上限 // gemini-2.5-flash 上限
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit { if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit {
budget = Gemini25FlashThinkingBudgetLimit budget = Gemini25FlashThinkingBudgetLimit
} }
config.ThinkingConfig.ThinkingBudget = budget
// 自动修正:max_tokens 必须大于 budget_tokens // 自动修正:max_tokens 必须大于 budget_tokens(Claude 上游要求)
if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok { if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok {
log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)", log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)",
config.MaxOutputTokens, adjusted, budget) config.MaxOutputTokens, adjusted, budget)
config.MaxOutputTokens = adjusted config.MaxOutputTokens = adjusted
} }
} }
config.ThinkingConfig.ThinkingBudget = budget
} }
if config.MaxOutputTokens > maxLimit { if config.MaxOutputTokens > maxLimit {
......
...@@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { ...@@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
}) })
} }
} }
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
tests := []struct {
name string
model string
thinking *ThinkingConfig
wantBudget int
wantPresent bool
}{
{
name: "enabled without budget defaults to dynamic (-1)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled"},
wantBudget: -1,
wantPresent: true,
},
{
name: "enabled with budget uses the provided value",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024},
wantBudget: 1024,
wantPresent: true,
},
{
name: "enabled with -1 budget uses dynamic (-1)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1},
wantBudget: -1,
wantPresent: true,
},
{
name: "adaptive on opus4.6 maps to high budget (24576)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000},
wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens,
wantPresent: true,
},
{
name: "adaptive on non-opus model keeps default dynamic (-1)",
model: "claude-sonnet-4-5-thinking",
thinking: &ThinkingConfig{Type: "adaptive"},
wantBudget: -1,
wantPresent: true,
},
{
name: "disabled does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
wantBudget: 0,
wantPresent: false,
},
{
name: "nil thinking does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: nil,
wantBudget: 0,
wantPresent: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &ClaudeRequest{
Model: tt.model,
Thinking: tt.thinking,
}
cfg := buildGenerationConfig(req)
if cfg == nil {
t.Fatalf("expected non-nil generationConfig")
}
if tt.wantPresent {
if cfg.ThinkingConfig == nil {
t.Fatalf("expected thinkingConfig to be present")
}
if !cfg.ThinkingConfig.IncludeThoughts {
t.Fatalf("expected includeThoughts=true")
}
if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget {
t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget)
}
return
}
if cfg.ThinkingConfig != nil {
t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig)
}
})
}
}
...@@ -282,7 +282,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -282,7 +282,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
if geminiResp.UsageMetadata != nil { if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount cached := geminiResp.UsageMetadata.CachedContentTokenCount
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
} }
......
...@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
if geminiResp.UsageMetadata != nil { if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount cached := geminiResp.UsageMetadata.CachedContentTokenCount
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached p.cacheReadTokens = cached
} }
...@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
if v1Resp.Response.UsageMetadata != nil { if v1Resp.Response.UsageMetadata != nil {
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
} }
......
...@@ -15,7 +15,6 @@ type captureState struct { ...@@ -15,7 +15,6 @@ type captureState struct {
} }
type capturedWrite struct { type capturedWrite struct {
entry zapcore.Entry
fields []zapcore.Field fields []zapcore.Field
} }
...@@ -51,7 +50,6 @@ func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { ...@@ -51,7 +50,6 @@ func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
allFields = append(allFields, c.withFields...) allFields = append(allFields, c.withFields...)
allFields = append(allFields, fields...) allFields = append(allFields, fields...)
c.state.writes = append(c.state.writes, capturedWrite{ c.state.writes = append(c.state.writes, capturedWrite{
entry: entry,
fields: allFields, fields: allFields,
}) })
return nil return nil
......
...@@ -448,7 +448,12 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -448,7 +448,12 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
q = q.Where(dbaccount.TypeEQ(accountType)) q = q.Where(dbaccount.TypeEQ(accountType))
} }
if status != "" { if status != "" {
q = q.Where(dbaccount.StatusEQ(status)) switch status {
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
default:
q = q.Where(dbaccount.StatusEQ(status))
}
} }
if search != "" { if search != "" {
q = q.Where(dbaccount.NameContainsFold(search)) q = q.Where(dbaccount.NameContainsFold(search))
......
...@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err ...@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
SetPriority(rule.Priority). SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode). SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode). SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody) SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
if len(rule.ErrorCodes) > 0 { if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes) builder.SetErrorCodes(rule.ErrorCodes)
...@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err ...@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
SetPriority(rule.Priority). SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode). SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode). SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody) SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
// 处理可选字段 // 处理可选字段
if len(rule.ErrorCodes) > 0 { if len(rule.ErrorCodes) > 0 {
...@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model ...@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
Platforms: e.Platforms, Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode, PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody, PassthroughBody: e.PassthroughBody,
SkipMonitoring: e.SkipMonitoring,
CreatedAt: e.CreatedAt, CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt, UpdatedAt: e.UpdatedAt,
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
q = q.Where(redeemcode.StatusEQ(status)) q = q.Where(redeemcode.StatusEQ(status))
} }
if search != "" { if search != "" {
q = q.Where(redeemcode.CodeContainsFold(search)) q = q.Where(
redeemcode.Or(
redeemcode.CodeContainsFold(search),
redeemcode.HasUserWith(user.EmailContainsFold(search)),
),
)
} }
total, err := q.Count(ctx) total, err := q.Count(ctx)
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
...@@ -191,6 +192,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -191,6 +192,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser.EmailContainsFold(filters.Search), dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search), dbuser.UsernameContainsFold(filters.Search),
dbuser.NotesContainsFold(filters.Search), dbuser.NotesContainsFold(filters.Search),
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
), ),
) )
} }
......
...@@ -290,6 +290,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) ...@@ -290,6 +290,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
{ {
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL) antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode) antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken)
} }
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"log/slog" "log/slog"
mathrand "math/rand" mathrand "math/rand"
"net" "net"
...@@ -15,6 +16,7 @@ import ( ...@@ -15,6 +16,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -41,6 +43,12 @@ const ( ...@@ -41,6 +43,12 @@ const (
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
// 使用固定 1s 间隔重试,最多重试 60 次
antigravityModelCapacityRetryMaxAttempts = 60
antigravityModelCapacityRetryWait = 1 * time.Second
// Google RPC 状态和类型常量 // Google RPC 状态和类型常量
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
googleRPCStatusUnavailable = "UNAVAILABLE" googleRPCStatusUnavailable = "UNAVAILABLE"
...@@ -61,6 +69,9 @@ const ( ...@@ -61,6 +69,9 @@ const (
// 单账号 503 退避重试:原地重试的总累计等待时间上限 // 单账号 503 退避重试:原地重试的总累计等待时间上限
// 超过此上限将不再重试,直接返回 503 // 超过此上限将不再重试,直接返回 503
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
antigravityModelCapacityCooldown = 10 * time.Second
) )
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
...@@ -69,8 +80,15 @@ var antigravityPassthroughErrorMessages = []string{ ...@@ -69,8 +80,15 @@ var antigravityPassthroughErrorMessages = []string{
"prompt is too long", "prompt is too long",
} }
// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试
var (
modelCapacityExhaustedMu sync.RWMutex
modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until
)
const ( const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
) )
...@@ -132,6 +150,20 @@ type antigravityRetryLoopResult struct { ...@@ -132,6 +150,20 @@ type antigravityRetryLoopResult struct {
resp *http.Response resp *http.Response
} }
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
func resolveAntigravityForwardBaseURL() string {
baseURLs := antigravity.ForwardBaseURLs()
if len(baseURLs) == 0 {
return ""
}
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
if mode == "prod" && len(baseURLs) > 1 {
return baseURLs[1]
}
return baseURLs[0]
}
// smartRetryAction 智能重试的处理结果 // smartRetryAction 智能重试的处理结果
type smartRetryAction int type smartRetryAction int
...@@ -159,7 +191,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -159,7 +191,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
// 判断是否触发智能重试 // 判断是否触发智能重试
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody) shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// 情况1: retryDelay >= 阈值,限流模型并切换账号 // 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel { if shouldRateLimitModel {
...@@ -196,20 +228,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -196,20 +228,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
} }
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次) // 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED),智能重试
if shouldSmartRetry { if shouldSmartRetry {
var lastRetryResp *http.Response var lastRetryResp *http.Response
var lastRetryBody []byte var lastRetryBody []byte
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ { // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", maxAttempts := antigravitySmartRetryMaxAttempts
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID) if isModelCapacityExhausted {
maxAttempts = antigravityModelCapacityRetryMaxAttempts
waitDuration = antigravityModelCapacityRetryWait
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
if modelName != "" {
modelCapacityExhaustedMu.RLock()
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
modelCapacityExhaustedMu.RUnlock()
if exists && time.Now().Before(cooldownUntil) {
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
},
}
}
}
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_smart_retry", p.prefix) timer.Stop()
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-time.After(waitDuration): case <-timer.C:
} }
// 智能重试:创建新请求 // 智能重试:创建新请求
...@@ -229,13 +289,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -229,13 +289,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts) log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
if isModelCapacityExhausted && modelName != "" {
modelCapacityExhaustedMu.Lock()
delete(modelCapacityExhaustedUntil, modelName)
modelCapacityExhaustedMu.Unlock()
}
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
} }
// 网络错误时,继续重试 // 网络错误时,继续重试
if retryErr != nil || retryResp == nil { if retryErr != nil || retryResp == nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr) log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr)
continue continue
} }
...@@ -245,13 +311,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -245,13 +311,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
lastRetryResp = retryResp lastRetryResp = retryResp
if retryResp != nil { if retryResp != nil {
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
} }
// 解析新的重试信息,用于下次重试的等待时间 // 解析新的重试信息,用于下次重试的等待时间(MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过)
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil { if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil {
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newShouldRetry && newWaitDuration > 0 { if newShouldRetry && newWaitDuration > 0 {
waitDuration = newWaitDuration waitDuration = newWaitDuration
} }
...@@ -268,6 +334,27 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -268,6 +334,27 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryBody = respBody retryBody = respBody
} }
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义
// 直接返回上游错误响应,不设置模型限流,不切换账号
if isModelCapacityExhausted {
// 设置 cooldown,让后续请求快速失败,避免重复重试
if modelName != "" {
modelCapacityExhaustedMu.Lock()
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
modelCapacityExhaustedMu.Unlock()
}
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
},
}
}
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
...@@ -283,8 +370,8 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -283,8 +370,8 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
} }
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
resetAt := time.Now().Add(rateLimitDuration) resetAt := time.Now().Add(rateLimitDuration)
if p.accountRepo != nil && modelName != "" { if p.accountRepo != nil && modelName != "" {
...@@ -368,11 +455,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( ...@@ -368,11 +455,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
timer.Stop()
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix) logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-time.After(waitDuration): case <-timer.C:
} }
totalWaited += waitDuration totalWaited += waitDuration
...@@ -406,12 +495,12 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( ...@@ -406,12 +495,12 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
_ = lastRetryResp.Body.Close() _ = lastRetryResp.Body.Close()
} }
lastRetryResp = retryResp lastRetryResp = retryResp
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
// 解析新的重试信息,更新下次等待时间 // 解析新的重试信息,更新下次等待时间
if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil { if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil {
_, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) _, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newWaitDuration > 0 { if newWaitDuration > 0 {
waitDuration = newWaitDuration waitDuration = newWaitDuration
if waitDuration > antigravitySingleAccountSmartRetryMaxWait { if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
...@@ -467,10 +556,11 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP ...@@ -467,10 +556,11 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
} }
} }
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() baseURL := resolveAntigravityForwardBaseURL()
if len(availableURLs) == 0 { if baseURL == "" {
availableURLs = antigravity.BaseURLs return nil, errors.New("no antigravity forward base url configured")
} }
availableURLs := []string{baseURL}
var resp *http.Response var resp *http.Response
var usedBaseURL string var usedBaseURL string
...@@ -908,11 +998,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -908,11 +998,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// URL fallback 循环 baseURL := resolveAntigravityForwardBaseURL()
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() if baseURL == "" {
if len(availableURLs) == 0 { return nil, errors.New("no antigravity forward base url configured")
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
} }
availableURLs := []string{baseURL}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
...@@ -1217,7 +1307,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1217,7 +1307,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
} }
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
// 获取 access_token // 获取 access_token
...@@ -1373,7 +1463,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1373,7 +1463,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
break break
} }
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests { if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := "" retryBaseURL := ""
...@@ -1454,6 +1544,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1454,6 +1544,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if isGoogleProjectConfigError(msg) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
upstreamDetail := s.getUpstreamErrorDetail(respBody)
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
...@@ -1994,6 +2105,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1994,6 +2105,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// Always record upstream context for Ops error logs, even when we will failover. // Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) {
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -2089,6 +2216,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) ...@@ -2089,6 +2216,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
} }
} }
// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。
// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。
// 适用于所有走 Google 后端的平台(Antigravity、Gemini)。
func isGoogleProjectConfigError(lowerMsg string) bool {
// Google 间歇性 Bug:Project ID 有效但被临时识别失败
return strings.Contains(lowerMsg, "invalid project resource name")
}
// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长
const googleConfigErrorCooldown = 1 * time.Minute
// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁,
// 避免短时间内反复调度到同一个有问题的账号。
func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(googleConfigErrorCooldown)
reason := "400: invalid project resource name (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// emptyResponseCooldown 空流式响应的临时封禁时长
const emptyResponseCooldown = 1 * time.Minute
// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁,
// 避免短时间内反复调度到同一个返回空响应的账号。
func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(emptyResponseCooldown)
reason := "empty stream response (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待 // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待,false 表示 context 已取消 // 返回 true 表示正常完成等待,false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
...@@ -2105,10 +2270,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { ...@@ -2105,10 +2270,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
sleepFor = 0 sleepFor = 0
} }
timer := time.NewTimer(sleepFor)
select { select {
case <-ctx.Done(): case <-ctx.Done():
timer.Stop()
return false return false
case <-time.After(sleepFor): case <-timer.C:
return true return true
} }
} }
...@@ -2153,8 +2320,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) { ...@@ -2153,8 +2320,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
// antigravitySmartRetryInfo 智能重试所需的信息 // antigravitySmartRetryInfo 智能重试所需的信息
type antigravitySmartRetryInfo struct { type antigravitySmartRetryInfo struct {
RetryDelay time.Duration // 重试延迟时间 RetryDelay time.Duration // 重试延迟时间
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5") ModelName string // 限流的模型名称(如 "claude-sonnet-4-5")
IsModelCapacityExhausted bool // 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED)
} }
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息 // parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
...@@ -2269,31 +2437,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { ...@@ -2269,31 +2437,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
} }
return &antigravitySmartRetryInfo{ return &antigravitySmartRetryInfo{
RetryDelay: retryDelay, RetryDelay: retryDelay,
ModelName: modelName, ModelName: modelName,
IsModelCapacityExhausted: hasModelCapacityExhausted,
} }
} }
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试 // shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
// 返回: // 返回:
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold) // - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold // - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0) // - waitDuration: 等待时间
// - modelName: 限流的模型名称 // - modelName: 限流的模型名称
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) { // - isModelCapacityExhausted: 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED)
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) {
if account.Platform != PlatformAntigravity { if account.Platform != PlatformAntigravity {
return false, false, 0, "" return false, false, 0, "", false
} }
info := parseAntigravitySmartRetryInfo(respBody) info := parseAntigravitySmartRetryInfo(respBody)
if info == nil { if info == nil {
return false, false, 0, "" return false, false, 0, "", false
} }
// MODEL_CAPACITY_EXHAUSTED(模型容量不足):所有账号共享同一模型容量池
// 切换账号无意义,使用固定 1s 间隔重试
if info.IsModelCapacityExhausted {
return true, false, antigravityModelCapacityRetryWait, info.ModelName, true
}
// RATE_LIMIT_EXCEEDED(账号级限流):
// retryDelay >= 阈值:直接限流模型,不重试 // retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
if info.RetryDelay >= antigravityRateLimitThreshold { if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, info.RetryDelay, info.ModelName return false, true, info.RetryDelay, info.ModelName, false
} }
// retryDelay < 阈值:智能重试 // retryDelay < 阈值:智能重试
...@@ -2302,7 +2479,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou ...@@ -2302,7 +2479,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
waitDuration = antigravitySmartRetryMinWait waitDuration = antigravitySmartRetryMinWait
} }
return true, false, waitDuration, info.ModelName return true, false, waitDuration, info.ModelName, false
} }
// handleModelRateLimitParams 模型级限流处理参数 // handleModelRateLimitParams 模型级限流处理参数
...@@ -2328,8 +2505,9 @@ type handleModelRateLimitResult struct { ...@@ -2328,8 +2505,9 @@ type handleModelRateLimitResult struct {
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用) // handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
// 仅处理 429/503,解析模型名和 retryDelay // 仅处理 429/503,解析模型名和 retryDelay
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试 // - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true(实际重试由 handleSmartRetry 处理)
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError // - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true,由调用方等待后重试
// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult { func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
if p.statusCode != 429 && p.statusCode != 503 { if p.statusCode != 429 && p.statusCode != 503 {
return &handleModelRateLimitResult{Handled: false} return &handleModelRateLimitResult{Handled: false}
...@@ -2340,7 +2518,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit ...@@ -2340,7 +2518,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
return &handleModelRateLimitResult{Handled: false} return &handleModelRateLimitResult{Handled: false}
} }
// < antigravityRateLimitThreshold: 等待后重试 // MODEL_CAPACITY_EXHAUSTED:模型容量不足,所有账号共享同一容量池
// 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理)
if info.IsModelCapacityExhausted {
log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)",
p.prefix, p.statusCode, info.ModelName)
return &handleModelRateLimitResult{
Handled: true,
}
}
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold { if info.RetryDelay < antigravityRateLimitThreshold {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v", logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay) p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
...@@ -2351,7 +2539,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit ...@@ -2351,7 +2539,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
} }
} }
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 // RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
s.setModelRateLimitAndClearSession(p, info) s.setModelRateLimitAndClearSession(p, info)
return &handleModelRateLimitResult{ return &handleModelRateLimitResult{
...@@ -2903,9 +3091,14 @@ returnResponse: ...@@ -2903,9 +3091,14 @@ returnResponse:
// 选择最后一个有效响应 // 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts) finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 // 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil { if last == nil && lastWithParts == nil {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response, no valid chunks received") logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
} }
// 如果收集到了图片 parts,需要合并到最终响应中 // 如果收集到了图片 parts,需要合并到最终响应中
...@@ -3123,6 +3316,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou ...@@ -3123,6 +3316,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
} }
// 检查错误透传规则
if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule(
c, account.Platform, upstreamStatus, body,
0, "", "",
); matched {
c.JSON(ptStatus, gin.H{
"type": "error",
"error": gin.H{"type": ptErrType, "message": ptErrMsg},
})
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int var statusCode int
var errType, errMsg string var errType, errMsg string
...@@ -3320,10 +3528,14 @@ returnResponse: ...@@ -3320,10 +3528,14 @@ returnResponse:
// 选择最后一个有效响应 // 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts) finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 // 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil { if last == nil && lastWithParts == nil {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response, no valid chunks received") logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
} }
// 将收集的所有 parts 合并到最终响应中 // 将收集的所有 parts 合并到最终响应中
......
...@@ -592,6 +592,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { ...@@ -592,6 +592,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
require.NotContains(t, body, "event: error") require.NotContains(t, body, "event: error")
} }
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
require.Equal(t, 90, result.usage.InputTokens)
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
require.Equal(t, 110, result.usage.OutputTokens)
require.Equal(t, 10, result.usage.CacheReadInputTokens)
}
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=50 → InputTokens=50
require.Equal(t, 50, result.usage.InputTokens)
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
require.Equal(t, 35, result.usage.OutputTokens)
}
// --- 流式客户端断开检测测试 --- // --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage // TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
......
...@@ -192,6 +192,43 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken ...@@ -192,6 +192,43 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr) return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
} }
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// 刷新 token
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// 获取用户信息(email)
client := antigravity.NewClient(proxyURL)
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
} else {
tokenInfo.Email = userInfo.Email
}
// 获取 project_id(容错,失败不阻塞)
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func isNonRetryableAntigravityOAuthError(err error) bool { func isNonRetryableAntigravityOAuthError(err error) bool {
msg := err.Error() msg := err.Error()
nonRetryable := []string{ nonRetryable := []string{
...@@ -273,12 +310,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac ...@@ -273,12 +310,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
} }
client := antigravity.NewClient(proxyURL) client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil return loadResp.CloudAICompanionProject, nil
} }
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误 // 记录错误
if err != nil { if err != nil {
lastErr = err lastErr = err
...@@ -292,6 +338,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac ...@@ -292,6 +338,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
} }
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
tierID := resolveDefaultTierID(loadRaw)
if tierID == "" {
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
}
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
}
return projectID, nil
}
func resolveDefaultTierID(loadRaw map[string]any) string {
if len(loadRaw) == 0 {
return ""
}
rawTiers, ok := loadRaw["allowedTiers"]
if !ok {
return ""
}
tiers, ok := rawTiers.([]any)
if !ok {
return ""
}
for _, rawTier := range tiers {
tier, ok := rawTier.(map[string]any)
if !ok {
continue
}
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
continue
}
if id, ok := tier["id"].(string); ok {
id = strings.TrimSpace(id)
if id != "" {
return id
}
}
}
return ""
}
// FillProjectID 仅获取 project_id,不刷新 OAuth token
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
}
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{ creds := map[string]any{
......
package service
import (
"testing"
)
func TestResolveDefaultTierID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
loadRaw map[string]any
want string
}{
{
name: "nil loadRaw",
loadRaw: nil,
want: "",
},
{
name: "missing allowedTiers",
loadRaw: map[string]any{
"paidTier": map[string]any{"id": "g1-pro-tier"},
},
want: "",
},
{
name: "empty allowedTiers",
loadRaw: map[string]any{"allowedTiers": []any{}},
want: "",
},
{
name: "tier missing id field",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"isDefault": true},
},
},
want: "",
},
{
name: "allowedTiers but no default",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": false},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "",
},
{
name: "default tier found",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": true},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "free-tier",
},
{
name: "default tier id with spaces",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": " standard-tier ", "isDefault": true},
},
},
want: "standard-tier",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := resolveDefaultTierID(tc.loadRaw)
if got != tc.want {
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
}
})
}
}
...@@ -92,7 +92,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i ...@@ -92,7 +92,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil return nil
} }
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability oldAvailability := antigravity.DefaultURLAvailability
defer func() { defer func() {
...@@ -137,15 +139,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { ...@@ -137,15 +139,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.resp) require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }() defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode) require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.False(t, handleErrorCalled) require.True(t, handleErrorCalled)
require.Len(t, upstream.calls, 2) require.Len(t, upstream.calls, antigravityMaxRetries)
require.True(t, strings.HasPrefix(upstream.calls[0], base1)) for _, callURL := range upstream.calls {
require.True(t, strings.HasPrefix(upstream.calls[1], base2)) require.True(t, strings.HasPrefix(callURL, base1))
}
available := antigravity.DefaultURLAvailability.GetAvailableURLs() available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available) require.NotEmpty(t, available)
require.Equal(t, base2, available[0]) require.Equal(t, base1, available[0])
} }
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
...@@ -194,13 +197,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { ...@@ -194,13 +197,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
} }
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 // TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { // MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo} svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity} account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流 // 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
body := []byte(`{ body := []byte(`{
"error": { "error": {
"status": "UNAVAILABLE", "status": "UNAVAILABLE",
...@@ -213,13 +217,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { ...@@ -213,13 +217,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流 // MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
// 实际重试由 handleSmartRetry 处理
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.Handled) require.True(t, result.Handled)
require.NotNil(t, result.SwitchError) require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel) require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
require.Len(t, repo.modelRateLimitCalls, 1) require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
} }
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理) // TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
...@@ -307,11 +311,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) { ...@@ -307,11 +311,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
func TestParseAntigravitySmartRetryInfo(t *testing.T) { func TestParseAntigravitySmartRetryInfo(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
body string body string
expectedDelay time.Duration expectedDelay time.Duration
expectedModel string expectedModel string
expectedNil bool expectedNil bool
expectedIsModelCapacityExhausted bool
}{ }{
{ {
name: "valid complete response with RATE_LIMIT_EXCEEDED", name: "valid complete response with RATE_LIMIT_EXCEEDED",
...@@ -374,8 +379,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { ...@@ -374,8 +379,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
"message": "No capacity available for model gemini-3-pro-high on the server" "message": "No capacity available for model gemini-3-pro-high on the server"
} }
}`, }`,
expectedDelay: 39 * time.Second, expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high", expectedModel: "gemini-3-pro-high",
expectedIsModelCapacityExhausted: true,
}, },
{ {
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil", name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
...@@ -486,6 +492,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { ...@@ -486,6 +492,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
if result.ModelName != tt.expectedModel { if result.ModelName != tt.expectedModel {
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel) t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
} }
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
}) })
} }
} }
...@@ -497,13 +506,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ...@@ -497,13 +506,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
apiKeyAccount := &Account{Type: AccountTypeAPIKey} apiKeyAccount := &Account{Type: AccountTypeAPIKey}
tests := []struct { tests := []struct {
name string name string
account *Account account *Account
body string body string
expectedShouldRetry bool expectedShouldRetry bool
expectedShouldRateLimit bool expectedShouldRateLimit bool
minWait time.Duration expectedIsModelCapacityExhausted bool
modelName string minWait time.Duration
modelName string
}{ }{
{ {
name: "OAuth account with short delay (< 7s) - smart retry", name: "OAuth account with short delay (< 7s) - smart retry",
...@@ -617,13 +627,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ...@@ -617,13 +627,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
] ]
} }
}`, }`,
expectedShouldRetry: false, expectedShouldRetry: true,
expectedShouldRateLimit: true, expectedShouldRateLimit: false,
minWait: 39 * time.Second, expectedIsModelCapacityExhausted: true,
modelName: "gemini-3-pro-high", minWait: 1 * time.Second,
modelName: "gemini-3-pro-high",
}, },
{ {
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit", name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
account: oauthAccount, account: oauthAccount,
body: `{ body: `{
"error": { "error": {
...@@ -635,10 +646,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ...@@ -635,10 +646,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
"message": "No capacity available for model gemini-2.5-flash on the server" "message": "No capacity available for model gemini-2.5-flash on the server"
} }
}`, }`,
expectedShouldRetry: false, expectedShouldRetry: true,
expectedShouldRateLimit: true, expectedShouldRateLimit: false,
minWait: 30 * time.Second, expectedIsModelCapacityExhausted: true,
modelName: "gemini-2.5-flash", minWait: 1 * time.Second,
modelName: "gemini-2.5-flash",
}, },
{ {
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit", name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
...@@ -662,13 +674,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ...@@ -662,13 +674,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
if shouldRetry != tt.expectedShouldRetry { if shouldRetry != tt.expectedShouldRetry {
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry) t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
} }
if shouldRateLimit != tt.expectedShouldRateLimit { if shouldRateLimit != tt.expectedShouldRateLimit {
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit) t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
} }
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
if shouldRetry { if shouldRetry {
if wait < tt.minWait { if wait < tt.minWait {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait) t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
...@@ -921,6 +936,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) { ...@@ -921,6 +936,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
} }
} }
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
defer func() {
antigravity.BaseURLs = oldBaseURLs
}()
prodURL := "https://prod.test"
dailyURL := "https://daily.test"
antigravity.BaseURLs = []string{dailyURL, prodURL}
resolved := resolveAntigravityForwardBaseURL()
require.Equal(t, dailyURL, resolved)
}
func TestAntigravityAccountSwitchError_Error(t *testing.T) { func TestAntigravityAccountSwitchError_Error(t *testing.T) {
err := &AntigravityAccountSwitchError{ err := &AntigravityAccountSwitchError{
OriginalAccountID: 789, OriginalAccountID: 789,
......
...@@ -153,13 +153,14 @@ func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *te ...@@ -153,13 +153,14 @@ func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *te
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
} }
// 503 + 39s >= 7s 阈值 // 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel)
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
] ]
} }
...@@ -339,13 +340,14 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi ...@@ -339,13 +340,14 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit // TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流 // 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) { func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
failRespBody := `{ failRespBody := `{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
] ]
} }
...@@ -371,9 +373,9 @@ func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *t ...@@ -371,9 +373,9 @@ func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *t
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
] ]
} }
......
...@@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test ...@@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
} }
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError // TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) { // MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
account := &Account{ account := &Account{
ID: 3, ID: 3,
...@@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi ...@@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
} }
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值 // 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s)
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
...@@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi ...@@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Body: io.NopCloser(bytes.NewReader(respBody)), Body: io.NopCloser(bytes.NewReader(respBody)),
} }
// mock: 第 1 次重试返回 200 成功
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
},
errors: []error{nil},
}
params := antigravityRetryLoopParams{ params := antigravityRetryLoopParams{
ctx: context.Background(), ctx: context.Background(),
prefix: "[test]", prefix: "[test]",
...@@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi ...@@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
action: "generateContent", action: "generateContent",
body: []byte(`{"input":"test"}`), body: []byte(`{"input":"test"}`),
accountRepo: repo, accountRepo: repo,
httpUpstream: upstream,
isStickySession: true, isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil return nil
...@@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi ...@@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action) require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp) require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.err) require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted") require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置 // 不应设置模型限流
require.Len(t, repo.modelRateLimitCalls, 1) require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) require.Len(t, upstream.calls, 1, "should have made one retry call before success")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// 立即取消上下文,验证重试循环能正确退出
ctx, cancel := context.WithCancel(context.Background())
cancel()
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
require.Nil(t, result.switchError, "should not return switchError on context cancel")
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
} }
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑 // TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
...@@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t ...@@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
} }
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession // TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 // 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{ failRespBody := `{
"error": { "error": {
"code": 503, "code": 429,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
] ]
} }
}` }`
failResp := &http.Response{ failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable, StatusCode: http.StatusTooManyRequests,
Header: http.Header{}, Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)), Body: io.NopCloser(strings.NewReader(failRespBody)),
} }
...@@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession ...@@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 429,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
] ]
} }
}`) }`)
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusServiceUnavailable, StatusCode: http.StatusTooManyRequests,
Header: http.Header{}, Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)), Body: io.NopCloser(bytes.NewReader(respBody)),
} }
......
...@@ -7,12 +7,14 @@ import ( ...@@ -7,12 +7,14 @@ import (
"log/slog" "log/slog"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
const ( const (
antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
) )
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) // AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
...@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct { ...@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenCache AntigravityTokenCache tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
} }
func NewAntigravityTokenProvider( func NewAntigravityTokenProvider(
...@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err != nil { if err != nil {
return "", err return "", err
} }
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) p.mergeCredentials(account, tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
} }
...@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
...@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil return accessToken, nil
} }
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string { func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
......
...@@ -61,6 +61,11 @@ func applyErrorPassthroughRule( ...@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
errMsg = *rule.CustomMessage errMsg = *rule.CustomMessage
} }
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
if rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error" errType = "upstream_error"
return status, errType, errMsg, true return status, errType, errMsg, true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment