Commit fd43be8d authored by yangjianbo
Browse files

merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突



- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parents 792bef61 836ba14b
......@@ -36,8 +36,8 @@ type UsageLogRepository interface {
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
......
//go:build unit
package service
import (
"testing"
)
// TestMatchWildcard runs matchWildcard over a table of patterns: exact strings,
// patterns ending in "*", and boundary inputs (empty strings, a lone "*").
func TestMatchWildcard(t *testing.T) {
	type wildcardCase struct {
		label   string
		pattern string
		input   string
		want    bool
	}

	cases := []wildcardCase{
		// Exact matching.
		{label: "exact match", pattern: "claude-sonnet-4-5", input: "claude-sonnet-4-5", want: true},
		{label: "exact mismatch", pattern: "claude-sonnet-4-5", input: "claude-opus-4-5", want: false},

		// Wildcard matching.
		{label: "wildcard prefix match", pattern: "claude-*", input: "claude-sonnet-4-5", want: true},
		{label: "wildcard prefix match 2", pattern: "claude-*", input: "claude-opus-4-5-thinking", want: true},
		{label: "wildcard prefix mismatch", pattern: "claude-*", input: "gemini-3-flash", want: false},
		{label: "wildcard partial match", pattern: "gemini-3*", input: "gemini-3-flash", want: true},
		{label: "wildcard partial match 2", pattern: "gemini-3*", input: "gemini-3-pro-image", want: true},
		{label: "wildcard partial mismatch", pattern: "gemini-3*", input: "gemini-2.5-flash", want: false},

		// Boundary cases.
		{label: "empty pattern exact", pattern: "", input: "", want: true},
		{label: "empty pattern mismatch", pattern: "", input: "claude", want: false},
		{label: "single star", pattern: "*", input: "anything", want: true},
		{label: "star at end only", pattern: "abc*", input: "abcdef", want: true},
		{label: "star at end empty suffix", pattern: "abc*", input: "abc", want: true},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			if got := matchWildcard(tc.pattern, tc.input); got != tc.want {
				t.Errorf("matchWildcard(%q, %q) = %v, want %v", tc.pattern, tc.input, got, tc.want)
			}
		})
	}
}
// TestMatchWildcardMapping checks how matchWildcardMapping resolves a requested
// model against a mapping table: exact keys, competing wildcard keys, and the
// cases where the requested model is expected to come back unchanged.
func TestMatchWildcardMapping(t *testing.T) {
	type mappingCase struct {
		label string
		rules map[string]string
		model string
		want  string
	}

	cases := []mappingCase{
		// An exact key wins over a wildcard key.
		{
			label: "exact match takes precedence",
			rules: map[string]string{
				"claude-sonnet-4-5": "claude-sonnet-4-5-exact",
				"claude-*":          "claude-default",
			},
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-4-5-exact",
		},
		// The longest wildcard key wins.
		{
			label: "longer wildcard takes precedence",
			rules: map[string]string{
				"claude-*":         "claude-default",
				"claude-sonnet-*":  "claude-sonnet-default",
				"claude-sonnet-4*": "claude-sonnet-4-series",
			},
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-4-series",
		},
		// A single wildcard key.
		{
			label: "single wildcard",
			rules: map[string]string{
				"claude-*": "claude-mapped",
			},
			model: "claude-opus-4-5",
			want:  "claude-mapped",
		},
		// No match: the original model is returned.
		{
			label: "no match returns original",
			rules: map[string]string{
				"claude-*": "claude-mapped",
			},
			model: "gemini-3-flash",
			want:  "gemini-3-flash",
		},
		// Empty mapping: the original model is returned.
		{
			label: "empty mapping returns original",
			rules: map[string]string{},
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-4-5",
		},
		// Gemini model mapping.
		{
			label: "gemini wildcard mapping",
			rules: map[string]string{
				"gemini-3*":   "gemini-3-pro-high",
				"gemini-2.5*": "gemini-2.5-flash",
			},
			model: "gemini-3-flash-preview",
			want:  "gemini-3-pro-high",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			if got := matchWildcardMapping(tc.rules, tc.model); got != tc.want {
				t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tc.rules, tc.model, got, tc.want)
			}
		})
	}
}
// TestAccountIsModelSupported checks Account.IsModelSupported for accounts with
// no credentials, empty credentials, and a "model_mapping" credential holding
// exact or wildcard keys.
func TestAccountIsModelSupported(t *testing.T) {
	// withMapping wraps mapping rules in the credentials shape used by these cases.
	withMapping := func(rules map[string]any) map[string]any {
		return map[string]any{"model_mapping": rules}
	}

	type supportCase struct {
		label string
		creds map[string]any
		model string
		want  bool
	}

	cases := []supportCase{
		// No mapping: every model is allowed.
		{
			label: "no mapping allows all",
			creds: nil,
			model: "any-model",
			want:  true,
		},
		{
			label: "empty mapping allows all",
			creds: map[string]any{},
			model: "any-model",
			want:  true,
		},
		// Exact matching.
		{
			label: "exact match supported",
			creds: withMapping(map[string]any{"claude-sonnet-4-5": "target-model"}),
			model: "claude-sonnet-4-5",
			want:  true,
		},
		{
			label: "exact match not supported",
			creds: withMapping(map[string]any{"claude-sonnet-4-5": "target-model"}),
			model: "claude-opus-4-5",
			want:  false,
		},
		// Wildcard matching.
		{
			label: "wildcard match supported",
			creds: withMapping(map[string]any{"claude-*": "claude-sonnet-4-5"}),
			model: "claude-opus-4-5-thinking",
			want:  true,
		},
		{
			label: "wildcard match not supported",
			creds: withMapping(map[string]any{"claude-*": "claude-sonnet-4-5"}),
			model: "gemini-3-flash",
			want:  false,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			acct := &Account{Credentials: tc.creds}
			if got := acct.IsModelSupported(tc.model); got != tc.want {
				t.Errorf("IsModelSupported(%q) = %v, want %v", tc.model, got, tc.want)
			}
		})
	}
}
// TestAccountGetMappedModel checks Account.GetMappedModel for accounts without
// credentials and with a "model_mapping" credential holding exact or wildcard
// keys, including the case where no key matches.
func TestAccountGetMappedModel(t *testing.T) {
	// withMapping wraps mapping rules in the credentials shape used by these cases.
	withMapping := func(rules map[string]any) map[string]any {
		return map[string]any{"model_mapping": rules}
	}

	type mappedCase struct {
		label string
		creds map[string]any
		model string
		want  string
	}

	cases := []mappedCase{
		// No mapping: the original model is returned.
		{
			label: "no mapping returns original",
			creds: nil,
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-4-5",
		},
		// Exact matching.
		{
			label: "exact match",
			creds: withMapping(map[string]any{"claude-sonnet-4-5": "target-model"}),
			model: "claude-sonnet-4-5",
			want:  "target-model",
		},
		// Wildcard matching (longest key first).
		{
			label: "wildcard longest match",
			creds: withMapping(map[string]any{
				"claude-*":        "claude-default",
				"claude-sonnet-*": "claude-sonnet-mapped",
			}),
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-mapped",
		},
		// No match: the original model is returned.
		{
			label: "no match returns original",
			creds: withMapping(map[string]any{"gemini-*": "gemini-mapped"}),
			model: "claude-sonnet-4-5",
			want:  "claude-sonnet-4-5",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			acct := &Account{Credentials: tc.creds}
			if got := acct.GetMappedModel(tc.model); got != tc.want {
				t.Errorf("GetMappedModel(%q) = %q, want %q", tc.model, got, tc.want)
			}
		})
	}
}
......@@ -56,6 +56,7 @@ type AdminService interface {
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
......@@ -179,6 +180,8 @@ type CreateAccountInput struct {
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
// SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty.
SkipDefaultGroupBind bool
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
......@@ -1076,7 +1079,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
if len(groupIDs) == 0 && !input.SkipDefaultGroupBind {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
......@@ -1444,6 +1447,10 @@ func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, erro
return s.proxyRepo.GetByID(ctx, id)
}
// GetProxiesByIDs passes ids straight to proxyRepo.ListByIDs and returns its
// result (proxies and error) unchanged.
func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
return s.proxyRepo.ListByIDs(ctx, ids)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &Proxy{
Name: input.Name,
......
......@@ -187,6 +187,10 @@ func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) {
panic("unexpected GetByID call")
}
// ListByIDs always panics: this stub treats any ListByIDs call as unexpected.
func (s *proxyRepoStub) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
panic("unexpected ListByIDs call")
}
// Update always panics: this stub treats any Update call as unexpected.
func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error {
panic("unexpected Update call")
}
......
......@@ -19,49 +19,65 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// Retry, rate-limit and Google RPC constants used by the Antigravity gateway.
const (
antigravityStickySessionTTL = time.Hour
antigravityDefaultMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
// Rate-limit related constants.
// antigravityRateLimitThreshold is the threshold between waiting and switching:
// - Smart retry: when retryDelay < this threshold, wait and retry; when >= this threshold, rate-limit the model directly.
// - Pre-check: when the remaining rate-limit time < this threshold, wait; when >= this threshold, switch accounts.
antigravityRateLimitThreshold = 7 * time.Second
antigravitySmartRetryMinWait = 1 * time.Second // minimum wait time for a smart retry
antigravitySmartRetryMaxAttempts = 3 // maximum number of smart-retry attempts
antigravityDefaultRateLimitDuration = 30 * time.Second // default rate-limit duration (used when there is no retryDelay)
// Google RPC status and type constants.
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
googleRPCStatusUnavailable = "UNAVAILABLE"
googleRPCTypeRetryInfo = "type.googleapis.com/google.rpc.RetryInfo"
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
)
// antigravityPassthroughErrorMessages is the whitelist (lowercase) of error messages that are passed through to the client.
// Matching uses strings.Contains, so an exact match is not required.
var antigravityPassthroughErrorMessages = []string{
"prompt is too long",
}
// Names of Antigravity gateway settings. Presumably these are environment
// variable names, going by the Env suffix and the GATEWAY_ prefix — TODO confirm
// against the code that reads them.
const (
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct {
ctx context.Context
prefix string
account *Account
proxyURL string
accessToken string
action string
body []byte
quotaScope AntigravityQuotaScope
maxRetries int
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
// AntigravityAccountSwitchError is the account-switch signal.
// It tells the upper layer to switch accounts when an account's rate-limit time exceeds the threshold.
type AntigravityAccountSwitchError struct {
OriginalAccountID int64
RateLimitedModel string
IsStickySession bool // whether this is a sticky-session switch (decides whether cache billing applies)
}
// antigravityRetryLoopResult 重试循环的结果
type antigravityRetryLoopResult struct {
resp *http.Response
// Error implements the error interface, naming the rate-limited account ID and
// model in the message.
func (e *AntigravityAccountSwitchError) Error() string {
	const format = "account %d model %s rate limited, need switch"
	return fmt.Sprintf(format, e.OriginalAccountID, e.RateLimitedModel)
}
// IsAntigravityAccountSwitchError reports whether err is, or wraps, an
// *AntigravityAccountSwitchError. On a match it returns that error and true;
// otherwise it returns nil and false.
func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, bool) {
	var target *AntigravityAccountSwitchError
	if !errors.As(err, &target) {
		return nil, false
	}
	return target, true
}
// PromptTooLongError 表示上游明确返回 prompt too long
......@@ -75,17 +91,207 @@ func (e *PromptTooLongError) Error() string {
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
}
// antigravityRetryLoopParams holds the parameters of the retry loop.
type antigravityRetryLoopParams struct {
ctx context.Context
prefix string
account *Account
proxyURL string
accessToken string
action string
body []byte
quotaScope AntigravityQuotaScope
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
accountRepo AccountRepository // used for model-level rate limiting during smart retry
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
requestedModel string // original requested model, used for rate-limit checks
isStickySession bool // whether this is a sticky session (used for the cache-billing decision when switching accounts)
groupID int64 // used to clear the sticky session on model-level rate limiting
sessionHash string // used to clear the sticky session on model-level rate limiting
}
// antigravityRetryLoopResult holds the result of the retry loop.
type antigravityRetryLoopResult struct {
resp *http.Response
}
// smartRetryAction is the outcome of smart-retry handling.
type smartRetryAction int
const (
smartRetryActionContinue smartRetryAction = iota // continue with the default retry logic
smartRetryActionBreakWithResp // end the loop and return resp
smartRetryActionContinueURL // continue the URL fallback loop
)
// smartRetryResult is the result of a smart retry.
type smartRetryResult struct {
action smartRetryAction
resp *http.Response
err error
switchError *AntigravityAccountSwitchError // account-switch signal returned when the model is rate limited
}
// handleSmartRetry handles the smart-retry path for a 429/503 upstream response
// (described by the original authors as the OAuth-account path). It is split out
// of antigravityRetryLoop to keep that loop's complexity down.
//
// Parameters:
//   - p: the retry-loop parameters (account, request data, upstream client, callbacks).
//   - resp, respBody: the rate-limited response and its already-read body; resp.Body
//     is not read here.
//   - baseURL, urlIdx, availableURLs: the URL that produced resp and its position in
//     the URL fallback list.
//
// The action of the returned result tells the caller what to do next:
//   - smartRetryActionContinueURL: URL-level 429 and a further URL is available.
//   - smartRetryActionBreakWithResp: stop looping; exactly one of err (context
//     cancelled while waiting), switchError (model rate limited, switch account)
//     or resp (retry succeeded, or the retry request could not be built) is set.
//   - smartRetryActionContinue: smart retry was not triggered; fall back to the
//     default retry logic.
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
	// "Resource has been exhausted" is a URL-level rate limit: switch URL (429 only).
	if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
		log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
		return &smartRetryResult{action: smartRetryActionContinueURL}
	}

	// Decide whether smart retry applies.
	shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)

	// Case 1: retryDelay >= threshold: rate-limit the model and switch accounts.
	if shouldRateLimitModel {
		log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)",
			p.prefix, resp.StatusCode, modelName, p.account.ID)

		resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
		if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
			p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
			log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID)
		} else {
			s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
		}

		// Return the account-switch signal so the upper layer retries on another account.
		return &smartRetryResult{
			action: smartRetryActionBreakWithResp,
			switchError: &AntigravityAccountSwitchError{
				OriginalAccountID: p.account.ID,
				RateLimitedModel:  modelName,
				IsStickySession:   p.isStickySession,
			},
		}
	}

	// Case 2: retryDelay < threshold: smart retry (at most antigravitySmartRetryMaxAttempts times).
	if shouldSmartRetry {
		for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
			log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
				p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)

			// An explicit timer (rather than time.After) can be stopped right away
			// when the context is cancelled instead of lingering until it fires.
			timer := time.NewTimer(waitDuration)
			select {
			case <-p.ctx.Done():
				timer.Stop()
				log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
				return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
			case <-timer.C:
			}

			// Smart retry: build a fresh request.
			retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
			if err != nil {
				log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
				p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
				// Hand the original rate-limited response back with a re-readable body.
				return &smartRetryResult{
					action: smartRetryActionBreakWithResp,
					resp: &http.Response{
						StatusCode: resp.StatusCode,
						Header:     resp.Header.Clone(),
						Body:       io.NopCloser(bytes.NewReader(respBody)),
					},
				}
			}

			retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
			if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
				log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
				return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
			}

			// Network error: keep retrying.
			if retryErr != nil || retryResp == nil {
				// Defensive: if the upstream returned a response together with an
				// error, release its body so the connection is not leaked.
				if retryResp != nil && retryResp.Body != nil {
					_ = retryResp.Body.Close()
				}
				log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
				continue
			}

			// Still 429/503: read a bounded amount of the body and release it now.
			// The read error is ignored on purpose; the body is only used on a
			// best-effort basis to refresh the wait time below.
			retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
			_ = retryResp.Body.Close()

			// Parse the new retry info to get the wait time for the next attempt.
			if attempt < antigravitySmartRetryMaxAttempts && retryBody != nil {
				newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, retryBody)
				if newShouldRetry && newWaitDuration > 0 {
					waitDuration = newWaitDuration
				}
			}
		}

		// Every retry failed: rate-limit the current model and switch accounts.
		log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)",
			p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID)

		resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
		if p.accountRepo != nil && modelName != "" {
			if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
				log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
			} else {
				log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
					p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration)
				s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
			}
		}

		// Return the account-switch signal so the upper layer retries on another account.
		return &smartRetryResult{
			action: smartRetryActionBreakWithResp,
			switchError: &AntigravityAccountSwitchError{
				OriginalAccountID: p.account.ID,
				RateLimitedModel:  modelName,
				IsStickySession:   p.isStickySession,
			},
		}
	}

	// Smart retry not triggered: continue with the default retry logic.
	return &smartRetryResult{action: smartRetryActionContinue}
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
baseURLs := antigravity.ForwardBaseURLs()
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs)
if len(availableURLs) == 0 {
availableURLs = baseURLs
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,根据剩余时间决定等待或切换
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
if remaining < antigravityRateLimitThreshold {
// 限流剩余时间较短,等待后继续
log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
select {
case <-p.ctx.Done():
return nil, p.ctx.Err()
case <-time.After(remaining):
}
} else {
// 限流剩余时间较长,返回账号切换信号
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
}
}
}
maxRetries := p.maxRetries
if maxRetries <= 0 {
maxRetries = antigravityDefaultMaxRetries
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs
}
var resp *http.Response
......@@ -105,7 +311,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL
for attempt := 1; attempt <= maxRetries; attempt++ {
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
......@@ -124,6 +330,9 @@ urlFallbackLoop:
}
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && resp == nil {
err = errors.New("upstream returned nil response")
}
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
......@@ -138,8 +347,8 @@ urlFallbackLoop:
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < maxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err)
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -151,19 +360,31 @@ urlFallbackLoop:
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
// 429 限流处理:区分 URL 级别限流和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests {
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// "Resource has been exhausted" 是 URL 级别限流,切换 URL
if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
// 尝试智能重试处理(OAuth 账号专用)
smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs)
switch smartResult.action {
case smartRetryActionContinueURL:
continue urlFallbackLoop
case smartRetryActionBreakWithResp:
if smartResult.err != nil {
return nil, smartResult.err
}
// 模型限流时返回切换账号信号
if smartResult.switchError != nil {
return nil, smartResult.switchError
}
resp = smartResult.resp
break urlFallbackLoop
}
// smartRetryActionContinue: 继续默认重试逻辑
// 账户/模型配额限流,重试 3 次(指数退避)
if attempt < maxRetries {
// 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败)
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
......@@ -176,7 +397,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200))
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -185,8 +406,8 @@ urlFallbackLoop:
}
// 重试用尽,标记账户限流
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope)
log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200))
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
......@@ -195,12 +416,12 @@ urlFallbackLoop:
break urlFallbackLoop
}
// 其他可重试错误
// 其他可重试错误(不包括 429 和 503,因为上面已处理)
if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < maxRetries {
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
......@@ -213,7 +434,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500))
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -301,71 +522,34 @@ func logPrefix(sessionID, accountName string) string {
return fmt.Sprintf("[antigravity-Forward] account=%s", accountName)
}
// antigravitySupportedModels lists the models Antigravity supports directly (exact match, passed through).
// Note: the gemini-2.5 series has been removed and is uniformly mapped to the gemini-3 series.
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-image": true,
}
// antigravityPrefixMapping is the Antigravity prefix mapping table (ordered by descending prefix length so the longest match wins).
// It handles model version-suffix variations (such as -20251111, -thinking, -preview).
// The gemini-2.5 series is uniformly mapped to the gemini-3 series (the Antigravity upstream no longer supports 2.5).
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// gemini-2.5 → gemini-3 mappings (longer prefixes first)
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
// gemini-3 prefix mappings
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview and similar
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview and similar → gemini-3-flash
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview and similar
// Claude mappings
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // legacy claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "claude-sonnet-4-5"}, // legacy claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"},
}
// AntigravityGatewayService handles API forwarding for the Antigravity platform.
type AntigravityGatewayService struct {
accountRepo AccountRepository
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
accountRepo AccountRepository
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
cache GatewayCache // used to clear sticky-session bindings on model-level rate limiting
schedulerSnapshot *SchedulerSnapshotService
}
// NewAntigravityGatewayService builds an AntigravityGatewayService, storing each
// supplied dependency in the struct field of the same name.
func NewAntigravityGatewayService(
accountRepo AccountRepository,
_ GatewayCache,
cache GatewayCache,
schedulerSnapshot *SchedulerSnapshotService,
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
cache: cache,
schedulerSnapshot: schedulerSnapshot,
}
}
......@@ -374,33 +558,80 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
return s.tokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 账户级映射(用户自定义优先)
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
// getLogConfig returns the upstream error logging settings: whether the error
// body should be logged, and the maximum number of bytes to keep.
// When the setting service or its config is missing it returns (false, 2048).
// Otherwise it returns Gateway.LogUpstreamErrorBody together with
// Gateway.LogUpstreamErrorBodyMaxBytes, falling back to 2048 when that value is
// not positive.
func (s *AntigravityGatewayService) getLogConfig() (logBody bool, maxBytes int) {
	const defaultMaxBytes = 2048

	if s.settingService == nil || s.settingService.cfg == nil {
		return false, defaultMaxBytes
	}

	gateway := s.settingService.cfg.Gateway
	logBody, maxBytes = gateway.LogUpstreamErrorBody, defaultMaxBytes
	if gateway.LogUpstreamErrorBodyMaxBytes > 0 {
		maxBytes = gateway.LogUpstreamErrorBodyMaxBytes
	}
	return logBody, maxBytes
}
// 2. 直接支持的模型透传
if antigravitySupportedModels[requestedModel] {
return requestedModel
// getUpstreamErrorDetail returns the upstream error body for logging, truncated
// to the configured byte limit, or an empty string when body logging is disabled.
func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
	if enabled, limit := s.getLogConfig(); enabled {
		return truncateString(string(body), limit)
	}
	return ""
}
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
// mapAntigravityModel returns the mapped model name.
// It relies entirely on the mapping configuration: account mapping (wildcards) → default mapping fallback (DefaultAntigravityModelMapping).
// Note: an empty string means the model is not supported; such accounts are filtered out during scheduling.
func mapAntigravityModel(account *Account, requestedModel string) string {
if account == nil {
return ""
}
// Fetch the mapping table (DefaultAntigravityModelMapping is used automatically when none is configured)
mapping := account.GetModelMapping()
if len(mapping) == 0 {
return "" // no mapping configured (non-Antigravity platform)
}
// Look up through the mapping table (supports exact match + wildcards)
mapped := account.GetMappedModel(requestedModel)
// Check whether the mapping succeeded (mapped != requestedModel means a mapping rule was found)
if mapped != requestedModel {
return mapped
}
// 4. Gemini model passthrough (gemini models that matched no prefix)
if strings.HasPrefix(requestedModel, "gemini-") {
// If mapped == requestedModel, check whether the model is configured in the mapping table (exact or wildcard).
// This distinguishes the following cases:
// 1. The mapping table has "model-a": "model-a" (explicit passthrough) → return model-a
// 2. A wildcard match such as "claude-*": "claude-sonnet-4-5" whose target happens to equal the requested name → return model-a
// 3. The mapping table has no configuration for model-a → return empty (unsupported)
if account.IsModelSupported(requestedModel) {
return requestedModel
}
// 5. Default value
return "claude-sonnet-4-5"
// Model not configured in the mapping table: return an empty string (unsupported)
return ""
}
// getMappedModel returns the mapped model name by delegating to mapAntigravityModel.
// It relies entirely on the mapping configuration: account mapping (wildcards) → default mapping fallback.
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
return mapAntigravityModel(account, requestedModel)
}
// applyThinkingModelSuffix adjusts the mapped model name for the thinking setting.
// When thinking is enabled and the mapped model is exactly "claude-sonnet-4-5",
// it returns "claude-sonnet-4-5-thinking"; in every other case the mapped model
// is returned unchanged.
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
	if thinkingEnabled && mappedModel == "claude-sonnet-4-5" {
		return "claude-sonnet-4-5-thinking"
	}
	return mappedModel
}
// IsModelSupported 检查模型是否被支持
......@@ -419,11 +650,6 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 上游透传账号使用专用测试方法
if account.Type == AccountTypeUpstream {
return s.testUpstreamConnection(ctx, account, modelID)
}
// 获取 token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
......@@ -438,6 +664,9 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 模型映射
mappedModel := s.getMappedModel(account, modelID)
if mappedModel == "" {
return nil, fmt.Errorf("model %s not in whitelist", modelID)
}
// 构建请求体
var requestBody []byte
......@@ -518,87 +747,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, lastErr
}
// testUpstreamConnection checks connectivity for an upstream passthrough account
// by POSTing one minimal Claude-style request to <base_url>/v1/messages.
//
// The account's "base_url" and "api_key" credentials are both required; one
// trailing "/" on the base URL is removed before the path is appended. When
// modelID is empty, "claude-sonnet-4-20250514" is used. The request sets
// max_tokens to 1 and sends the API key both as a Bearer token and as
// x-api-key.
//
// On success the result carries the "text" of the first content block of the
// response (empty when the body is not JSON or has no such block) together with
// the model ID that was sent. Request construction failures, transport
// failures, read failures and any HTTP status >= 400 are returned as errors.
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
	endpoint := strings.TrimSpace(account.GetCredential("base_url"))
	key := strings.TrimSpace(account.GetCredential("api_key"))
	if endpoint == "" || key == "" {
		return nil, errors.New("upstream account missing base_url or api_key")
	}
	endpoint = strings.TrimSuffix(endpoint, "/") + "/v1/messages"

	// Fall back to a Claude model when the caller did not choose one.
	if modelID == "" {
		modelID = "claude-sonnet-4-20250514"
	}

	// Smallest possible probe: one-character prompt, one output token.
	payload, err := json.Marshal(map[string]any{
		"model":      modelID,
		"max_tokens": 1,
		"messages": []map[string]any{
			{"role": "user", "content": "."},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("构建请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)
	req.Header.Set("x-api-key", key)
	req.Header.Set("anthropic-version", "2023-06-01")

	// Route through the account's proxy only when one is actually attached.
	var proxyURL string
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, endpoint)

	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Cap the amount read from the response at 2 MiB.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(raw))
	}

	// Pull the text of the first content block; any shape mismatch leaves the
	// reply empty rather than failing the test.
	var reply string
	var decoded map[string]any
	if json.Unmarshal(raw, &decoded) == nil {
		if blocks, _ := decoded["content"].([]any); len(blocks) > 0 {
			first, _ := blocks[0].(map[string]any)
			reply, _ = first["text"].(string)
		}
	}

	return &TestConnectionResult{
		Text:        reply,
		MappedModel: modelID,
	}, nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
......@@ -649,10 +797,6 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
}
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil {
opts.EnableMCPXML = group.MCPXMLInject
}
return opts
}
......@@ -820,12 +964,7 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
}
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
// 上游透传账号直接转发,不走 OAuth token 刷新
if account.Type == AccountTypeUpstream {
return s.ForwardUpstream(ctx, c, account, body)
}
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
......@@ -833,29 +972,30 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 解析 Claude 请求
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
}
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
if mappedModel == "" {
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
loadModel := mappedModel
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token")
}
// 获取 project_id(部分账户类型可能没有)
......@@ -875,30 +1015,46 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel)
}
// 执行带重试的请求
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession, // Forward 由上层判断粘性会话
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
})
if err != nil {
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
resp := result.resp
......@@ -913,15 +1069,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
logBody, maxBytes := s.getLogConfig()
upstreamDetail := s.getUpstreamErrorDetail(respBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
......@@ -960,20 +1109,24 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if txErr != nil {
continue
}
retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
})
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......@@ -1049,22 +1202,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
if resp.StatusCode == http.StatusBadRequest {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
}
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
upstreamDetail := s.getUpstreamErrorDetail(respBody)
logBody, maxBytes := s.getLogConfig()
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -1082,20 +1227,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Body: respBody,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
upstreamDetail := s.getUpstreamErrorDetail(respBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
......@@ -1143,7 +1281,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: billingModel, // 计费模型(可按映射模型覆盖)
Model: originalModel, // 使用原始模型用于计费和日志
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -1168,21 +1306,38 @@ func isSignatureRelatedError(respBody []byte) bool {
return true
}
// Detect thinking block modification errors:
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false
}
// isPromptTooLongError reports whether an upstream error body describes a
// "prompt too long" failure. It lower-cases and trims the message returned by
// extractAntigravityErrorMessage, falls back to the lower-cased raw body when
// that message is empty, and then searches the text for known substrings.
func isPromptTooLongError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
// NOTE(review): two return statements follow back to back, so as written only
// the first one can execute and the broader multi-pattern check is unreachable.
// This looks like the old and new versions of the line from a diff listed
// together — confirm which one is intended.
return strings.Contains(msg, "prompt is too long")
return strings.Contains(msg, "prompt is too long") ||
strings.Contains(msg, "request is too long") ||
strings.Contains(msg, "context length exceeded") ||
strings.Contains(msg, "max_tokens")
}
// isPassthroughErrorMessage reports whether msg matches the passthrough
// whitelist: the message is lower-cased and considered a match as soon as it
// contains any entry of antigravityPassthroughErrorMessages as a substring.
func isPassthroughErrorMessage(msg string) bool {
	haystack := strings.ToLower(msg)
	matched := false
	for _, allowed := range antigravityPassthroughErrorMessages {
		if matched = strings.Contains(haystack, allowed); matched {
			break
		}
	}
	return matched
}
// getPassthroughOrDefault picks the message to use for an upstream error:
// upstreamMsg itself when isPassthroughErrorMessage accepts it, and defaultMsg
// in every other case.
func getPassthroughOrDefault(upstreamMsg, defaultMsg string) string {
	if !isPassthroughErrorMessage(upstreamMsg) {
		return defaultMsg
	}
	return upstreamMsg
}
func extractAntigravityErrorMessage(body []byte) string {
......@@ -1191,41 +1346,15 @@ func extractAntigravityErrorMessage(body []byte) string {
return ""
}
parseNestedMessage := func(msg string) string {
trimmed := strings.TrimSpace(msg)
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
return ""
}
var nested map[string]any
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
return ""
}
if errObj, ok := nested["error"].(map[string]any); ok {
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg
}
......@@ -1521,7 +1650,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude, 0, "", false)
}
// 透传上游错误
......@@ -1656,7 +1785,7 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
......@@ -1686,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
default:
......@@ -1694,20 +1823,17 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
mappedModel := s.getMappedModel(account, originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
if mappedModel == "" {
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token")
}
// 获取 project_id(部分账户类型可能没有)
......@@ -1719,17 +1845,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL = account.Proxy.URL()
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody, err := filterEmptyPartsFromGeminiRequest(body)
if err != nil {
log.Printf("[Antigravity] Failed to filter empty parts: %v", err)
filteredBody = body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody)
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
if err != nil {
return nil, err
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Invalid request body")
}
// 清理 Schema
......@@ -1743,30 +1862,46 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
return nil, err
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
}
// 执行带重试的请求
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession, // ForwardGemini 由上层判断粘性会话
groupID: 0, // ForwardGemini 方法没有 groupID,由上层处理粘性会话清除
sessionHash: "", // ForwardGemini 方法没有 sessionHash,由上层处理粘性会话清除
})
if err != nil {
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
resp := result.resp
......@@ -1822,19 +1957,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(unwrappedForOps), maxBytes)
}
upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps)
// Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
......@@ -1913,7 +2039,7 @@ handleSuccess:
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: billingModel,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -1955,104 +2081,347 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func antigravityUseScopeRateLimit() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
// 默认开启按配额域限流,只有明确设置为禁用值时才关闭
if v == "0" || v == "false" || v == "no" || v == "off" {
// setModelRateLimitByModelName records a model-level rate limit for an account,
// keyed directly by modelName (e.g. "claude-sonnet-4-5") with no conversion to
// a quota scope.
//
// It returns false without touching the repository when repo is nil or
// modelName is empty, and false (after logging) when repo.SetModelRateLimit
// fails. On success it logs the limit and returns true; afterSmartRetry only
// selects which log event name is written.
func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, accountID int64, modelName, prefix string, statusCode int, resetAt time.Time, afterSmartRetry bool) bool {
	if modelName == "" || repo == nil {
		return false
	}

	err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt)
	if err != nil {
		log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
		return false
	}

	// Remaining time until the limit lifts, rounded down to whole seconds.
	resetIn := time.Until(resetAt).Truncate(time.Second)
	if !afterSmartRetry {
		log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, resetIn)
		return true
	}
	log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, resetIn)
	return true
}
func antigravityHasAccountSwitch(ctx context.Context) bool {
if ctx == nil {
return false
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
if raw == "" {
return 0, false
}
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
return v > 0
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return 0, false
}
return false
return time.Duration(seconds) * time.Second, true
}
func antigravityMaxRetries() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
if raw == "" {
return antigravityDefaultMaxRetries
// antigravitySmartRetryInfo holds the information needed for a smart retry,
// as produced by parseAntigravitySmartRetryInfo from an upstream error body.
type antigravitySmartRetryInfo struct {
RetryDelay time.Duration // how long to wait before retrying
ModelName string // name of the rate-limited model (e.g. "claude-sonnet-4-5")
}
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
// 返回解析结果,如果解析失败或不满足条件返回 nil
//
// 支持两种情况:
// 1. 429 RESOURCE_EXHAUSTED + RATE_LIMIT_EXCEEDED:
// - error.status == "RESOURCE_EXHAUSTED"
// - error.details[].reason == "RATE_LIMIT_EXCEEDED"
//
// 2. 503 UNAVAILABLE + MODEL_CAPACITY_EXHAUSTED:
// - error.status == "UNAVAILABLE"
// - error.details[].reason == "MODEL_CAPACITY_EXHAUSTED"
//
// 必须满足以下条件才会返回有效值:
// - error.details[] 中存在 @type == "type.googleapis.com/google.rpc.RetryInfo" 的元素
// - 该元素包含 retryDelay 字段,格式为 "数字s"(如 "0.201506475s")
func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityDefaultMaxRetries
// 检查 status 是否符合条件
// 情况1: 429 RESOURCE_EXHAUSTED (需要进一步检查 reason == RATE_LIMIT_EXCEEDED)
// 情况2: 503 UNAVAILABLE (需要进一步检查 reason == MODEL_CAPACITY_EXHAUSTED)
status, _ := errObj["status"].(string)
isResourceExhausted := status == googleRPCStatusResourceExhausted
isUnavailable := status == googleRPCStatusUnavailable
if !isResourceExhausted && !isUnavailable {
return nil
}
details, ok := errObj["details"].([]any)
if !ok {
return nil
}
var retryDelay time.Duration
var modelName string
var hasRateLimitExceeded bool // 429 需要此 reason
var hasModelCapacityExhausted bool // 503 需要此 reason
for _, d := range details {
dm, ok := d.(map[string]any)
if !ok {
continue
}
atType, _ := dm["@type"].(string)
// 从 ErrorInfo 提取模型名称和 reason
if atType == googleRPCTypeErrorInfo {
if meta, ok := dm["metadata"].(map[string]any); ok {
if model, ok := meta["model"].(string); ok {
modelName = model
}
}
// 检查 reason
if reason, ok := dm["reason"].(string); ok {
if reason == googleRPCReasonModelCapacityExhausted {
hasModelCapacityExhausted = true
}
if reason == googleRPCReasonRateLimitExceeded {
hasRateLimitExceeded = true
}
}
continue
}
// 从 RetryInfo 提取重试延迟
if atType == googleRPCTypeRetryInfo {
delay, ok := dm["retryDelay"].(string)
if !ok || delay == "" {
continue
}
// 使用 time.ParseDuration 解析,支持所有 Go duration 格式
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
dur, err := time.ParseDuration(delay)
if err != nil {
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
continue
}
retryDelay = dur
}
}
// 验证条件
// 情况1: RESOURCE_EXHAUSTED 需要有 RATE_LIMIT_EXCEEDED reason
// 情况2: UNAVAILABLE 需要有 MODEL_CAPACITY_EXHAUSTED reason
if isResourceExhausted && !hasRateLimitExceeded {
return nil
}
if isUnavailable && !hasModelCapacityExhausted {
return nil
}
// 必须有模型名才返回有效结果
if modelName == "" {
return nil
}
// 如果上游未提供 retryDelay,使用默认限流时间
if retryDelay <= 0 {
retryDelay = antigravityDefaultRateLimitDuration
}
return &antigravitySmartRetryInfo{
RetryDelay: retryDelay,
ModelName: modelName,
}
return value
}
func antigravityMaxRetriesAfterSwitch() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
if raw == "" {
return antigravityMaxRetries()
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
// 返回:
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold)
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold)
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0)
// - modelName: 限流的模型名称
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
if account.Platform != PlatformAntigravity {
return false, false, 0, ""
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityMaxRetries()
info := parseAntigravitySmartRetryInfo(respBody)
if info == nil {
return false, false, 0, ""
}
// retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟
if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, 0, info.ModelName
}
return value
// retryDelay < 阈值:智能重试
waitDuration = info.RetryDelay
if waitDuration < antigravitySmartRetryMinWait {
waitDuration = antigravitySmartRetryMinWait
}
return true, false, waitDuration, info.ModelName
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
var envKey string
if strings.HasPrefix(model, "claude-") {
envKey = antigravityMaxRetriesClaudeEnv
} else if isImageGenerationModel(model) {
envKey = antigravityMaxRetriesGeminiImageEnv
} else if strings.HasPrefix(model, "gemini-") {
envKey = antigravityMaxRetriesGeminiTextEnv
// handleModelRateLimitParams bundles the inputs for model-level rate-limit
// handling (see handleModelRateLimit).
type handleModelRateLimitParams struct {
// ctx is forwarded to the repository and cache calls.
ctx context.Context
// prefix is prepended to every log line written while handling.
prefix string
// account is the account the upstream error was received for.
account *Account
// statusCode is the upstream HTTP status; only 429 and 503 are processed.
statusCode int
// body is the raw upstream error body handed to parseAntigravitySmartRetryInfo.
body []byte
// cache is used to delete the sticky-session binding; nil skips that step.
cache GatewayCache
// groupID and sessionHash identify the sticky-session binding to delete;
// an empty sessionHash skips the deletion.
groupID int64
sessionHash string
// isStickySession is copied into the resulting AntigravityAccountSwitchError.
isStickySession bool
}
// handleModelRateLimitResult is the outcome of model-level rate-limit handling.
type handleModelRateLimitResult struct {
Handled bool // whether the error was handled as a model-level rate limit
ShouldRetry bool // whether the caller should wait and then retry
WaitDuration time.Duration // how long to wait before retrying
SwitchError *AntigravityAccountSwitchError // account-switch error, set when the model was rate-limited
}
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
// 仅处理 429/503,解析模型名和 retryDelay
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
if p.statusCode != 429 && p.statusCode != 503 {
return &handleModelRateLimitResult{Handled: false}
}
if envKey != "" {
if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" {
if value, err := strconv.Atoi(raw); err == nil && value > 0 {
return value
}
info := parseAntigravitySmartRetryInfo(p.body)
if info == nil || info.ModelName == "" {
return &handleModelRateLimitResult{Handled: false}
}
// < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold {
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
return &handleModelRateLimitResult{
Handled: true,
ShouldRetry: true,
WaitDuration: info.RetryDelay,
}
}
if afterSwitch {
return antigravityMaxRetriesAfterSwitch()
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
s.setModelRateLimitAndClearSession(p, info)
return &handleModelRateLimitResult{
Handled: true,
SwitchError: &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: info.ModelName,
IsStickySession: p.isStickySession,
},
}
return antigravityMaxRetries()
}
func antigravityUseMappedModelForBilling() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv)))
return v == "1" || v == "true" || v == "yes" || v == "on"
// setModelRateLimitAndClearSession marks info.ModelName as rate-limited for the
// account in p and then drops the sticky-session binding.
//
// Steps, in order:
//  1. compute the reset time as now + info.RetryDelay and log the event;
//  2. persist the limit through accountRepo.SetModelRateLimit (a failure is
//     logged and does not stop the remaining steps);
//  3. push the new limit into the cached account via
//     updateAccountModelRateLimitInCache so it is visible immediately;
//  4. delete the session→account binding when a cache and a session hash are
//     available (best effort: the delete error is deliberately ignored).
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
	var (
		acct    = p.account
		model   = info.ModelName
		resetAt = time.Now().Add(info.RetryDelay)
	)

	log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
		p.prefix, p.statusCode, model, acct.ID, info.RetryDelay)

	// Persist the model rate-limit state.
	err := s.accountRepo.SetModelRateLimit(p.ctx, acct.ID, model, resetAt)
	if err != nil {
		log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, model, err)
	}

	// Refresh the cached account right away so the new limit is visible.
	s.updateAccountModelRateLimitInCache(p.ctx, acct, model, resetAt)

	// Nothing to unbind without a cache or a session hash.
	if p.cache == nil || p.sessionHash == "" {
		return
	}
	_ = p.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
if raw == "" {
return 0, false
// updateAccountModelRateLimitInCache 立即更新 Redis 中账号的模型限流状态
func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx context.Context, account *Account, modelKey string, resetAt time.Time) {
if s.schedulerSnapshot == nil || account == nil || modelKey == "" {
return
}
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return 0, false
// 更新账号对象的 Extra 字段
if account.Extra == nil {
account.Extra = make(map[string]any)
}
limits, _ := account.Extra["model_rate_limits"].(map[string]any)
if limits == nil {
limits = make(map[string]any)
account.Extra["model_rate_limits"] = limits
}
limits[modelKey] = map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
// 更新 Redis 快照
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
}
return time.Duration(seconds) * time.Second, true
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
func (s *AntigravityGatewayService) handleUpstreamError(
ctx context.Context, prefix string, account *Account,
statusCode int, headers http.Header, body []byte,
quotaScope AntigravityQuotaScope,
groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult {
// ✨ 模型级限流处理(在原有逻辑之前)
result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx,
prefix: prefix,
account: account,
statusCode: statusCode,
body: body,
cache: s.cache,
groupID: groupID,
sessionHash: sessionHash,
isStickySession: isStickySession,
})
if result.Handled {
return result
}
// 503 仅处理模型限流(MODEL_CAPACITY_EXHAUSTED),非模型限流不做额外处理
// 避免将普通的 503 错误误判为账号问题
if statusCode == 503 {
return nil
}
// ========== 原有逻辑,保持不变 ==========
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
// 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。
if logBody, maxBytes := s.getLogConfig(); logBody {
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
}
useScopeLimit := quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败:使用配置的 fallback 时间,直接限流整个账户
fallbackMinutes := 5
// 解析失败:使用默认限流时间(与临时限流保持一致)
// 可通过配置或环境变量覆盖
defaultDur := antigravityDefaultRateLimitDuration
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
}
defaultDur := time.Duration(fallbackMinutes) * time.Minute
if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = fallbackDur
// 秒级环境变量优先级最高
if override, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = override
}
ra := time.Now().Add(defaultDur)
if useScopeLimit {
......@@ -2066,7 +2435,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
return nil
}
resetTime := time.Unix(*resetAt, 0)
if useScopeLimit {
......@@ -2080,16 +2449,17 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
return nil
}
// 其他错误码继续使用 rateLimitService
if s.rateLimitService == nil {
return
return nil
}
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
if shouldDisable {
log.Printf("%s status=%d marked_error", prefix, statusCode)
}
return nil
}
type antigravityStreamResult struct {
......@@ -2120,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2141,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2152,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2277,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2305,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2316,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2620,20 +2994,16 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
return fmt.Errorf("%s", message)
}
// WriteMappedClaudeError is the exported form of writeMappedClaudeError, made
// available to the handler layer (for example, for fallback error handling).
// It forwards all arguments unchanged and returns the wrapped call's error.
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(body), maxBytes)
}
logBody, maxBytes := s.getLogConfig()
upstreamDetail := s.getUpstreamErrorDetail(body)
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -2658,7 +3028,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
case 400:
statusCode = http.StatusBadRequest
errType = "invalid_request_error"
errMsg = "Invalid request"
errMsg = getPassthroughOrDefault(upstreamMsg, "Invalid request")
case 401:
statusCode = http.StatusBadGateway
errType = "authentication_error"
......@@ -2691,10 +3061,6 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
// WriteMappedClaudeError is the exported entry point for writeMappedClaudeError.
// It forwards all arguments unchanged and returns the wrapped call's error.
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN"
switch status {
......@@ -2728,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
var firstTokenMs *int
var last map[string]any
......@@ -2754,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2765,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2908,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
......@@ -2940,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2951,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......@@ -3121,8 +3491,8 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return json.Marshal(payload)
}
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
// filterEmptyPartsFromGeminiRequest 过滤 parts 为空的消息
// Gemini API 不接受 parts,需要在请求前过滤
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
......
......@@ -7,7 +7,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
......@@ -113,7 +115,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5",
"model": "claude-opus-4-6",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
......@@ -149,7 +151,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
},
}
result, err := svc.Forward(context.Background(), c, account, body)
result, err := svc.Forward(context.Background(), c, account, body, false)
require.Nil(t, result)
var promptErr *PromptTooLongError
......@@ -166,27 +168,261 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
require.Equal(t, "prompt_too_long", events[0].Kind)
}
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "4")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
// TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover checks
// that when the account carries a model rate limit whose remaining time is at
// least antigravityRateLimitThreshold, Forward returns an UpstreamFailoverError
// so that the handler switches accounts.
func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	payload := map[string]any{
		"model": "claude-opus-4-6",
		"messages": []map[string]any{
			{"role": "user", "content": "hi"},
		},
		"max_tokens": 1,
		"stream":     false,
	}
	body, err := json.Marshal(payload)
	require.NoError(t, err)

	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))

	// No real upstream call is needed: the pre-check returns the switch signal
	// directly.
	svc := &AntigravityGatewayService{
		tokenProvider: &AntigravityTokenProvider{},
		httpUpstream:  &httpUpstreamStub{resp: nil, err: nil},
	}

	// Model rate limit with 30 seconds remaining (> antigravityRateLimitThreshold 7s).
	resetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
	modelLimits := map[string]any{
		"claude-opus-4-6-thinking": map[string]any{
			"rate_limit_reset_at": resetAt,
		},
	}
	account := &Account{
		ID:          1,
		Name:        "acc-rate-limited",
		Platform:    PlatformAntigravity,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "token",
		},
		Extra: map[string]any{
			modelRateLimitsKey: modelLimits,
		},
	}

	result, err := svc.Forward(context.Background(), ginCtx, account, body, false)
	require.Nil(t, result, "Forward should not return result when model rate limited")
	require.NotNil(t, err, "Forward should return error")

	// Core check: the error must be an UpstreamFailoverError, not a plain 502.
	var switchErr *UpstreamFailoverError
	require.ErrorAs(t, err, &switchErr, "error should be UpstreamFailoverError to trigger account switch")
	require.Equal(t, http.StatusServiceUnavailable, switchErr.StatusCode)
	// Non-sticky request: ForceCacheBilling must stay false.
	require.False(t, switchErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
// TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover
// checks that ForwardGemini likewise turns an AntigravityAccountSwitchError
// into an UpstreamFailoverError.
func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	payload := map[string]any{
		"contents": []map[string]any{
			{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
		},
	}
	body, err := json.Marshal(payload)
	require.NoError(t, err)

	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))

	// No real upstream call is needed: the pre-check returns the switch signal
	// directly.
	svc := &AntigravityGatewayService{
		tokenProvider: &AntigravityTokenProvider{},
		httpUpstream:  &httpUpstreamStub{resp: nil, err: nil},
	}

	// Model rate limit with 30 seconds remaining (> antigravityRateLimitThreshold 7s).
	resetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
	modelLimits := map[string]any{
		"gemini-2.5-flash": map[string]any{
			"rate_limit_reset_at": resetAt,
		},
	}
	account := &Account{
		ID:          2,
		Name:        "acc-gemini-rate-limited",
		Platform:    PlatformAntigravity,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "token",
		},
		Extra: map[string]any{
			modelRateLimitsKey: modelLimits,
		},
	}

	result, err := svc.ForwardGemini(context.Background(), ginCtx, account, "gemini-2.5-flash", "generateContent", false, body, false)
	require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
	require.NotNil(t, err, "ForwardGemini should return error")

	// Core check: the error must be an UpstreamFailoverError, not a plain 502.
	var switchErr *UpstreamFailoverError
	require.ErrorAs(t, err, &switchErr, "error should be UpstreamFailoverError to trigger account switch")
	require.Equal(t, http.StatusServiceUnavailable, switchErr.StatusCode)
	// Non-sticky request: ForceCacheBilling must stay false.
	require.False(t, switchErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
require.Equal(t, 4, got)
// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling
// 验证:粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
require.Equal(t, 7, got)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]string{{"role": "user", "content": "hello"}},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 3,
Name: "acc-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.Forward(context.Background(), c, account, body, true)
require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "5")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
require.Equal(t, 5, got)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 4,
Name: "acc-gemini-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityStreamUpstreamResponse_UsageAndFirstToken feeds two SSE data
// lines through a pipe into streamUpstreamResponse and checks the usage
// counters, that a first-token latency was recorded, and that data was written
// to the client.
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	pipeReader, pipeWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pipeReader}

	lines := []string{
		"data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n",
		"data: {\"usage\":{\"output_tokens\":5}}\n",
	}
	go func() {
		defer func() { _ = pipeWriter.Close() }()
		for _, line := range lines {
			// Write errors are deliberately ignored; the assertions below catch
			// any resulting data loss.
			_, _ = pipeWriter.Write([]byte(line))
		}
	}()

	svc := &AntigravityGatewayService{}
	startedAt := time.Now().Add(-10 * time.Millisecond)
	usage, firstTokenMs := svc.streamUpstreamResponse(ginCtx, upstream, startedAt)
	_ = pipeReader.Close()

	require.NotNil(t, usage)
	require.Equal(t, 1, usage.InputTokens)
	// The second event overrides output_tokens.
	require.Equal(t, 5, usage.OutputTokens)
	require.Equal(t, 3, usage.CacheReadInputTokens)
	require.Equal(t, 4, usage.CacheCreationInputTokens)

	if firstTokenMs == nil {
		t.Fatalf("expected firstTokenMs to be set")
	}
	// Make sure something was passed through to the client.
	passedThrough := strings.Contains(recorder.Body.String(), "data:")
	require.True(t, passedThrough)
}
......@@ -8,53 +8,6 @@ import (
"github.com/stretchr/testify/require"
)
// TestIsAntigravityModelSupported is a table-driven check of
// IsAntigravityModelSupported over the model names listed below.
func TestIsAntigravityModelSupported(t *testing.T) {
	type supportCase struct {
		name  string
		model string
		want  bool
	}

	cases := []supportCase{
		// Directly supported models.
		{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
		{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
		{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
		{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
		{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
		{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},

		// Models that can be mapped.
		{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
		{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
		{"可映射 - claude-opus-4", "claude-opus-4", true},
		{"可映射 - claude-haiku-4", "claude-haiku-4", true},
		{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},

		// Gemini prefix passthrough.
		{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
		{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
		{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},

		// Claude prefix fallback.
		{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
		{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
		{"Claude前缀 - claude-future-version", "claude-future-version", true},

		// Unsupported models.
		{"不支持 - gpt-4", "gpt-4", false},
		{"不支持 - gpt-4o", "gpt-4o", false},
		{"不支持 - llama-3", "llama-3", false},
		{"不支持 - mistral-7b", "mistral-7b", false},
		{"不支持 - 空字符串", "", false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			supported := IsAntigravityModelSupported(tc.model)
			require.Equal(t, tc.want, supported, "model: %s", tc.model)
		})
	}
}
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
......@@ -64,7 +17,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
// 1. 账户级映射优先
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
......@@ -72,120 +25,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
name: "账户映射 - 可覆盖默认映射的模型",
requestedModel: "claude-sonnet-4-5",
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
expected: "my-custom-sonnet",
},
{
name: "账户映射 - 可覆盖未知模型",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
// 2. 默认映射(DefaultAntigravityModelMapping)
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
// 3. 默认映射中的透传(映射到自己)
{
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307",
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
name: "默认映射透传 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 2.5 → 3 映射
{
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
name: "默认映射透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "gemini-2.5-flash",
},
{
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
name: "默认映射透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-3-pro-high",
expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
name: "默认映射透传 - gemini-3-flash",
requestedModel: "gemini-3-flash",
accountMapping: nil,
expected: "gemini-future-model",
expected: "gemini-3-flash",
},
// 4. 直接支持的模型
// 4. 未在默认映射中的模型返回空字符串(不支持)
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
name: "未知模型 - claude-unknown 返回空",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
name: "未知模型 - claude-3-opus-20240229 返回空",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "",
},
// 5. 默认值 fallback(未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
name: "未知模型 - claude-opus-4 返回空",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
name: "未知模型 - gemini-future-model 返回空",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
}
......@@ -219,12 +176,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
// 空字符串和非 claude/gemini 前缀返回空字符串
{"空字符串", "", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
{"非claude/gemini前缀 - llama", "llama-3", ""},
}
for _, tt := range tests {
......@@ -248,10 +203,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 可映射(有明确前缀映射)
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传
// 前缀透传(claude 和 gemini 前缀)
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
......@@ -267,3 +222,58 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
})
}
}
// TestMapAntigravityModel_WildcardTargetEqualsRequest covers the edge case in
// which a wildcard mapping's target happens to equal the requested model name.
// For example, with {"claude-*": "claude-sonnet-4-5"} a request for
// "claude-sonnet-4-5" should pass.
func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
	type mappingCase struct {
		name      string
		mapping   map[string]any
		requested string
		want      string
	}

	claudeWildcard := func() map[string]any {
		return map[string]any{"claude-*": "claude-sonnet-4-5"}
	}

	cases := []mappingCase{
		{
			name:      "wildcard target equals request model",
			mapping:   claudeWildcard(),
			requested: "claude-sonnet-4-5",
			want:      "claude-sonnet-4-5",
		},
		{
			name:      "wildcard target differs from request model",
			mapping:   claudeWildcard(),
			requested: "claude-opus-4-6",
			want:      "claude-sonnet-4-5",
		},
		{
			name:      "wildcard no match",
			mapping:   claudeWildcard(),
			requested: "gpt-4o",
			want:      "",
		},
		{
			name:      "explicit passthrough same name",
			mapping:   map[string]any{"claude-sonnet-4-5": "claude-sonnet-4-5"},
			requested: "claude-sonnet-4-5",
			want:      "claude-sonnet-4-5",
		},
		{
			name:      "multiple wildcards target equals one request",
			mapping:   map[string]any{"claude-*": "claude-sonnet-4-5", "gemini-*": "gemini-2.5-flash"},
			requested: "gemini-2.5-flash",
			want:      "gemini-2.5-flash",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			creds := map[string]any{"model_mapping": tc.mapping}
			account := &Account{
				Platform:    PlatformAntigravity,
				Credentials: creds,
			}
			mapped := mapAntigravityModel(account, tc.requested)
			require.Equal(t, tc.want, mapped, "mapAntigravityModel(%q) = %q, want %q", tc.requested, mapped, tc.want)
		})
	}
}
package service
import (
"context"
"slices"
"strings"
"time"
......@@ -57,15 +58,20 @@ func normalizeAntigravityModelName(model string) string {
return normalized
}
// IsSchedulableForModel reports whether the account can be scheduled for
// requestedModel, taking Antigravity quota-scope rate limiting into account.
// It keeps the old signature for existing callers and delegates to
// IsSchedulableForModelWithContext with context.Background().
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
}
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.isModelRateLimited(requestedModel) {
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
......@@ -132,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
}
return result
}
// GetQuotaScopeRateLimitRemainingTime returns how much time is left on the
// quota-scope rate limit that applies to requestedModel.
// It returns 0 when the receiver is nil, the account is not an Antigravity
// account, the model resolves to no quota scope, no reset time is recorded,
// or the reset time has already passed.
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
	if a == nil {
		return 0
	}
	if a.Platform != PlatformAntigravity {
		return 0
	}

	scope, resolved := resolveAntigravityQuotaScope(requestedModel)
	if !resolved {
		return 0
	}

	deadline := a.antigravityQuotaScopeResetAt(scope)
	if deadline == nil {
		return 0
	}

	left := time.Until(*deadline)
	if left <= 0 {
		return 0
	}
	return left
}
// GetRateLimitRemainingTime returns the remaining rate-limit time for
// requestedModel (the larger of the model limit and the quota-scope limit).
// It returns 0 when the account is not rate limited or the limit has expired.
// It delegates to GetRateLimitRemainingTimeWithContext using context.Background().
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext returns the remaining rate-limit time
// for requestedModel: the larger of the model rate limit and the quota-scope
// rate limit. It returns 0 for a nil receiver, and 0 when both limits report 0.
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
	if a == nil {
		return 0
	}

	// Query the model limit first, then the scope limit, and keep the longer one.
	byModel := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
	byScope := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
	if byScope >= byModel {
		return byScope
	}
	return byModel
}
......@@ -21,6 +21,23 @@ type stubAntigravityUpstream struct {
calls []string
}
// recordingOKUpstream is a test upstream that answers every request with
// 200 OK and counts how many requests it has received.
type recordingOKUpstream struct {
	calls int
}

// Do increments the call counter and returns a new 200 response whose body
// is "ok". The request and the remaining arguments are not inspected.
func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	r.calls++

	okBody := io.NopCloser(strings.NewReader("ok"))
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       okBody,
	}
	return resp, nil
}

// DoWithTLS ignores enableTLSFingerprint and delegates to Do.
func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	resp, err := r.Do(req, proxyURL, accountID, accountConcurrency)
	return resp, err
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String()
s.calls = append(s.calls, url)
......@@ -53,10 +70,17 @@ type rateLimitCall struct {
resetAt time.Time
}
// modelRateLimitCall holds the arguments of one recorded SetModelRateLimit call.
type modelRateLimitCall struct {
accountID int64
modelKey string // stored key (should be the official model ID, e.g. "claude-sonnet-4-5")
resetAt time.Time
}
// stubAntigravityAccountRepo is a test double that embeds AccountRepository
// and records the rate-limit writes made through it.
//
// Each slice collects the calls of one kind, in call order. The fields are
// declared exactly once: the previous text listed scopeCalls and rateCalls
// twice, which Go rejects as duplicate field names.
type stubAntigravityAccountRepo struct {
	AccountRepository
	scopeCalls          []scopeLimitCall
	rateCalls           []rateLimitCall
	modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
......@@ -69,6 +93,11 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
return nil
}
// SetModelRateLimit records the call's arguments in modelRateLimitCalls and
// always returns nil. The context is not used.
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error {
	recorded := modelRateLimitCall{
		accountID: id,
		modelKey:  modelKey,
		resetAt:   resetAt,
	}
	s.modelRateLimitCalls = append(s.modelRateLimitCalls, recorded)
	return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
......@@ -93,18 +122,21 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
}
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
})
......@@ -123,14 +155,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true")
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
......@@ -140,20 +172,122 @@ func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "false")
// TestHandleUpstreamError_429_ModelRateLimit covers the 429 model rate-limit
// scenario: the response names a model and carries RATE_LIMIT_EXCEEDED.
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
	const limitedModel = "claude-sonnet-4-5"

	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 1, Name: "acc-1", Platform: PlatformAntigravity}

	// 429 + RATE_LIMIT_EXCEEDED + model name -> model rate limit.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)

	got := gateway.handleUpstreamError(context.Background(), "[test]", acc, http.StatusTooManyRequests, http.Header{}, payload, AntigravityQuotaScopeClaude, 0, "", false)

	// The model rate limit is expected to be triggered.
	require.NotNil(t, got)
	require.True(t, got.Handled)
	require.NotNil(t, got.SwitchError)
	require.Equal(t, limitedModel, got.SwitchError.RateLimitedModel)

	require.Len(t, accountRepo.modelRateLimitCalls, 1)
	require.Equal(t, limitedModel, accountRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit covers a 429 that is not a
// model rate limit; it is expected to fall through to the scope rate limit.
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}

	// 429 + ordinary rate-limit response (no RATE_LIMIT_EXCEEDED reason) -> scope rate limit.
	payload := buildGeminiRateLimitBody("5s")

	got := gateway.handleUpstreamError(context.Background(), "[test]", acc, http.StatusTooManyRequests, http.Header{}, payload, AntigravityQuotaScopeClaude, 0, "", false)

	// No model rate limit; the scope rate limit is used instead.
	require.Nil(t, got)
	require.Empty(t, accountRepo.modelRateLimitCalls)

	require.Len(t, accountRepo.scopeCalls, 1)
	firstScopeCall := accountRepo.scopeCalls[0]
	require.Equal(t, AntigravityQuotaScopeClaude, firstScopeCall.scope)
}
// TestHandleUpstreamError_503_ModelRateLimit covers the 503 model rate-limit
// scenario: the response names a model and carries MODEL_CAPACITY_EXHAUSTED.
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
	const limitedModel = "gemini-3-pro-high"

	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}

	// 503 + MODEL_CAPACITY_EXHAUSTED -> model rate limit.
	payload := []byte(`{
"error": {
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)

	got := gateway.handleUpstreamError(context.Background(), "[test]", acc, http.StatusServiceUnavailable, http.Header{}, payload, AntigravityQuotaScopeGeminiText, 0, "", false)

	// The model rate limit is expected to be triggered.
	require.NotNil(t, got)
	require.True(t, got.Handled)
	require.NotNil(t, got.SwitchError)
	require.Equal(t, limitedModel, got.SwitchError.RateLimitedModel)

	require.Len(t, accountRepo.modelRateLimitCalls, 1)
	require.Equal(t, limitedModel, accountRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_NonModelRateLimit covers a 503 that is not a
// model rate limit; no rate-limit write of any kind is expected.
func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 4, Name: "acc-4", Platform: PlatformAntigravity}

	// 503 + ordinary error (not MODEL_CAPACITY_EXHAUSTED) -> nothing is done.
	payload := []byte(`{
"error": {
"status": "UNAVAILABLE",
"message": "Service temporarily unavailable",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "SERVICE_UNAVAILABLE"}
]
}
}`)

	got := gateway.handleUpstreamError(context.Background(), "[test]", acc, http.StatusServiceUnavailable, http.Header{}, payload, AntigravityQuotaScopeGeminiText, 0, "", false)

	// A 503 that is not a model rate limit must not be processed at all.
	require.Nil(t, got)
	require.Empty(t, accountRepo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
	require.Empty(t, accountRepo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
	require.Empty(t, accountRepo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
// TestHandleUpstreamError_503_EmptyBody covers a 503 whose body is an empty
// JSON object; no rate-limit write of any kind is expected.
//
// The function previously also carried the statements of the removed
// "account limit when scope disabled" test: a second `account :=` and
// `body :=` in the same scope, a handleUpstreamError call with the old
// argument list, and assertions that rateCalls has one entry. Those lines
// could not compile and contradicted the final require.Empty(rateCalls),
// so only the 503 scenario is kept.
func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
	repo := &stubAntigravityAccountRepo{}
	svc := &AntigravityGatewayService{accountRepo: repo}
	account := &Account{ID: 5, Name: "acc-5", Platform: PlatformAntigravity}

	// 503 + empty response body -> nothing is done.
	body := []byte(`{}`)

	result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)

	// An empty 503 must not be processed at all.
	require.Nil(t, result)
	require.Empty(t, repo.modelRateLimitCalls)
	require.Empty(t, repo.scopeCalls)
	require.Empty(t, repo.rateCalls)
}
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
......@@ -188,3 +322,771 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
// buildGeminiRateLimitBody returns a JSON rate-limit error body whose
// error.details[0].metadata.quotaResetDelay is delay, quoted with %q.
func buildGeminiRateLimitBody(delay string) []byte {
	const bodyFormat = `{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`
	return fmt.Appendf(nil, bodyFormat, delay)
}
// TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp checks that a
// fractional quotaResetDelay ("0.1s") yields a reset timestamp one second
// after the current Unix second.
func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
	// Avoid flakiness around Unix second boundaries: poll until the current
	// time is within the first 800ms of its second.
	for time.Now().Nanosecond() >= 800*1e6 {
		time.Sleep(5 * time.Millisecond)
	}

	baseUnix := time.Now().Unix()
	resetAt := ParseGeminiRateLimitResetTime(buildGeminiRateLimitBody("0.1s"))

	require.NotNil(t, resetAt)
	require.Equal(t, baseUnix+1, *resetAt, "fractional seconds should be rounded up to the next second")
}
// TestParseAntigravitySmartRetryInfo is a table test for
// parseAntigravitySmartRetryInfo. Each case supplies a response body and
// either expects a nil result or a specific retry delay and model name.
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
	cases := []struct {
		name          string
		body          string
		expectedDelay time.Duration
		expectedModel string
		expectedNil   bool
	}{
		{
			name: "valid complete response with RATE_LIMIT_EXCEEDED",
			body: `{
"error": {
"code": 429,
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"domain": "cloudcode-pa.googleapis.com",
"metadata": {
"model": "claude-sonnet-4-5",
"quotaResetDelay": "201.506475ms"
},
"reason": "RATE_LIMIT_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "0.201506475s"
}
],
"message": "You have exhausted your capacity on this model.",
"status": "RESOURCE_EXHAUSTED"
}
}`,
			expectedDelay: 201506475 * time.Nanosecond,
			expectedModel: "claude-sonnet-4-5",
		},
		{
			name: "429 RESOURCE_EXHAUSTED without RATE_LIMIT_EXCEEDED - should return nil",
			body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"metadata": {"model": "claude-sonnet-4-5"},
"reason": "QUOTA_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "3s"
}
]
}
}`,
			expectedNil: true,
		},
		{
			name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
			body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`,
			expectedDelay: 39 * time.Second,
			expectedModel: "gemini-3-pro-high",
		},
		{
			name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
			body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "SERVICE_UNAVAILABLE"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
]
}
}`,
			expectedNil: true,
		},
		{
			name: "wrong status - should return nil",
			body: `{
"error": {
"code": 429,
"status": "INVALID_ARGUMENT",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			expectedNil: true,
		},
		{
			name: "missing status - should return nil",
			body: `{
"error": {
"code": 429,
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			expectedNil: true,
		},
		{
			name: "milliseconds format is now supported",
			body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test-model"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "500ms"}
]
}
}`,
			expectedDelay: 500 * time.Millisecond,
			expectedModel: "test-model",
		},
		{
			name: "minutes format is supported",
			body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "4m50s"}
]
}
}`,
			expectedDelay: 4*time.Minute + 50*time.Second,
			expectedModel: "gemini-3-pro",
		},
		{
			name: "missing model name - should return nil",
			body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			expectedNil: true,
		},
		{
			name:        "invalid JSON",
			body:        `not json`,
			expectedNil: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			info := parseAntigravitySmartRetryInfo([]byte(tc.body))

			// Settle the nil / non-nil expectation first; the field checks
			// below only run for a non-nil result that was expected.
			switch {
			case tc.expectedNil && info != nil:
				t.Errorf("expected nil, got %+v", info)
				return
			case tc.expectedNil:
				return
			case info == nil:
				t.Errorf("expected non-nil result")
				return
			}

			if info.RetryDelay != tc.expectedDelay {
				t.Errorf("RetryDelay = %v, want %v", info.RetryDelay, tc.expectedDelay)
			}
			if info.ModelName != tc.expectedModel {
				t.Errorf("ModelName = %q, want %q", info.ModelName, tc.expectedModel)
			}
		})
	}
}
// TestShouldTriggerAntigravitySmartRetry is a table test for
// shouldTriggerAntigravitySmartRetry. Each case pairs an account with a
// response body and states the expected retry / rate-limit decision, the
// minimum wait (checked only when a retry is reported) and the model name
// (checked only when a retry or a rate limit is reported).
func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
	var (
		oauthAccount      = &Account{Type: AccountTypeOAuth, Platform: PlatformAntigravity}
		setupTokenAccount = &Account{Type: AccountTypeSetupToken, Platform: PlatformAntigravity}
		upstreamAccount   = &Account{Type: AccountTypeUpstream, Platform: PlatformAntigravity}
		apiKeyAccount     = &Account{Type: AccountTypeAPIKey}
	)

	cases := []struct {
		name                    string
		account                 *Account
		body                    string
		expectedShouldRetry     bool
		expectedShouldRateLimit bool
		minWait                 time.Duration
		modelName               string
	}{
		{
			name:    "OAuth account with short delay (< 7s) - smart retry",
			account: oauthAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
			expectedShouldRetry:     true,
			expectedShouldRateLimit: false,
			minWait:                 1 * time.Second, // 0.5s < 1s, so the minimum wait of 1s is used
			modelName:               "claude-opus-4",
		},
		{
			name:    "SetupToken account with short delay - smart retry",
			account: setupTokenAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			expectedShouldRetry:     true,
			expectedShouldRateLimit: false,
			minWait:                 3 * time.Second,
			modelName:               "gemini-3-flash",
		},
		{
			name:    "OAuth account with long delay (>= 7s) - direct rate limit",
			account: oauthAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: true,
			modelName:               "claude-sonnet-4-5",
		},
		{
			name:    "Upstream account with short delay - smart retry",
			account: upstreamAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "2s"}
]
}
}`,
			expectedShouldRetry:     true,
			expectedShouldRateLimit: false,
			minWait:                 2 * time.Second,
			modelName:               "claude-sonnet-4-5",
		},
		{
			name:    "API Key account - should not trigger",
			account: apiKeyAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: false,
		},
		{
			name:    "OAuth account with exactly 7s delay - direct rate limit",
			account: oauthAccount,
			body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: true,
			modelName:               "gemini-pro",
		},
		{
			name:    "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
			account: oauthAccount,
			body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: true,
			modelName:               "gemini-3-pro-high",
		},
		{
			name:    "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
			account: oauthAccount,
			body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-2.5-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}
],
"message": "No capacity available for model gemini-2.5-flash on the server"
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: true,
			modelName:               "gemini-2.5-flash",
		},
		{
			name:    "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
			account: oauthAccount,
			body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`,
			expectedShouldRetry:     false,
			expectedShouldRateLimit: true,
			modelName:               "claude-sonnet-4-5",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			retry, rateLimit, waitFor, gotModel := shouldTriggerAntigravitySmartRetry(tc.account, []byte(tc.body))

			if retry != tc.expectedShouldRetry {
				t.Errorf("shouldRetry = %v, want %v", retry, tc.expectedShouldRetry)
			}
			if rateLimit != tc.expectedShouldRateLimit {
				t.Errorf("shouldRateLimit = %v, want %v", rateLimit, tc.expectedShouldRateLimit)
			}

			// The wait is only meaningful when a retry was reported.
			if retry && waitFor < tc.minWait {
				t.Errorf("wait = %v, want >= %v", waitFor, tc.minWait)
			}

			// The model name is only checked when some action was reported.
			acted := retry || rateLimit
			if acted && gotModel != tc.modelName {
				t.Errorf("modelName = %q, want %q", gotModel, tc.modelName)
			}
		})
	}
}
// TestSetModelRateLimitByModelName_UsesOfficialModelID verifies that the
// write side stores the official model ID as the rate-limit key, and that an
// empty model name is rejected without touching the repository.
func TestSetModelRateLimitByModelName_UsesOfficialModelID(t *testing.T) {
	const accountID int64 = 123

	cases := []struct {
		name             string
		modelName        string
		expectedModelKey string
		expectedSuccess  bool
	}{
		{
			name:             "claude-sonnet-4-5 should be stored as-is",
			modelName:        "claude-sonnet-4-5",
			expectedModelKey: "claude-sonnet-4-5",
			expectedSuccess:  true,
		},
		{
			name:             "gemini-3-pro-high should be stored as-is",
			modelName:        "gemini-3-pro-high",
			expectedModelKey: "gemini-3-pro-high",
			expectedSuccess:  true,
		},
		{
			name:             "gemini-3-flash should be stored as-is",
			modelName:        "gemini-3-flash",
			expectedModelKey: "gemini-3-flash",
			expectedSuccess:  true,
		},
		{
			name:             "empty model name should fail",
			modelName:        "",
			expectedModelKey: "",
			expectedSuccess:  false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			accountRepo := &stubAntigravityAccountRepo{}
			resetAt := time.Now().Add(30 * time.Second)

			ok := setModelRateLimitByModelName(
				context.Background(),
				accountRepo,
				accountID,
				tc.modelName,
				"[test]",
				429,
				resetAt,
				false, // afterSmartRetry
			)
			require.Equal(t, tc.expectedSuccess, ok)

			if !tc.expectedSuccess {
				require.Empty(t, accountRepo.modelRateLimitCalls)
				return
			}

			require.Len(t, accountRepo.modelRateLimitCalls, 1)
			recorded := accountRepo.modelRateLimitCalls[0]
			require.Equal(t, int64(123), recorded.accountID)
			// Key assertion: the stored key is the official model ID, not a scope.
			require.Equal(t, tc.expectedModelKey, recorded.modelKey, "should store official model ID, not scope")
			require.WithinDuration(t, resetAt, recorded.resetAt, time.Second)
		})
	}
}
// TestSetModelRateLimitByModelName_NotConvertToScope verifies that the model
// name is stored unchanged and is not converted into a scope name.
func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	resetAt := time.Now().Add(30 * time.Second)

	// Call setModelRateLimitByModelName with an official model ID.
	ok := setModelRateLimitByModelName(
		context.Background(),
		accountRepo,
		456,
		"claude-sonnet-4-5", // official model ID
		"[test]",
		429,
		resetAt,
		true, // afterSmartRetry
	)
	require.True(t, ok)
	require.Len(t, accountRepo.modelRateLimitCalls, 1)

	// Key assertion: the stored key is "claude-sonnet-4-5", not "claude_sonnet".
	storedKey := accountRepo.modelRateLimitCalls[0].modelKey
	require.Equal(t, "claude-sonnet-4-5", storedKey, "should NOT convert to scope like claude_sonnet")
	require.NotEqual(t, "claude_sonnet", storedKey, "should NOT be scope")
}
// TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold sets a
// model rate limit that resets in about two seconds and runs the retry loop
// under a 30ms context deadline. It expects the loop to return
// context.DeadlineExceeded with no result and without calling the upstream.
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
	const model = "claude-sonnet-4-5"

	okUpstream := &recordingOKUpstream{}

	// RFC3339 here is second-precision; keep it safely in the future.
	resetAt := time.Now().Add(2 * time.Second).Format(time.RFC3339)
	acc := &Account{
		ID:          1,
		Name:        "acc-1",
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
		Extra: map[string]any{
			modelRateLimitsKey: map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}

	gateway := &AntigravityGatewayService{}
	got, err := gateway.antigravityRetryLoop(antigravityRetryLoopParams{
		ctx:             ctx,
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		requestedModel:  model,
		httpUpstream:    okUpstream,
		isStickySession: true,
		handleError:     noopHandleError,
	})

	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, got)
	require.Equal(t, 0, okUpstream.calls, "should not call upstream while waiting on pre-check")
}
// TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold
// sets a model rate limit that resets in about eleven seconds and expects the
// retry loop to return an AntigravityAccountSwitchError describing the
// account and model, with no result and without calling the upstream.
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
	const model = "claude-sonnet-4-5"

	okUpstream := &recordingOKUpstream{}

	resetAt := time.Now().Add(11 * time.Second).Format(time.RFC3339)
	acc := &Account{
		ID:          2,
		Name:        "acc-2",
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
		Extra: map[string]any{
			modelRateLimitsKey: map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}

	gateway := &AntigravityGatewayService{}
	got, err := gateway.antigravityRetryLoop(antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		requestedModel:  model,
		httpUpstream:    okUpstream,
		isStickySession: true,
		handleError:     noopHandleError,
	})

	require.Nil(t, got)

	var switchErr *AntigravityAccountSwitchError
	require.ErrorAs(t, err, &switchErr)
	require.Equal(t, acc.ID, switchErr.OriginalAccountID)
	require.Equal(t, model, switchErr.RateLimitedModel)
	require.True(t, switchErr.IsStickySession)

	require.Equal(t, 0, okUpstream.calls, "should not call upstream when switching on pre-check")
}
// TestIsAntigravityAccountSwitchError is a table test for
// IsAntigravityAccountSwitchError covering a nil error, an unrelated error,
// a direct switch error and a switch error wrapped with %w.
func TestIsAntigravityAccountSwitchError(t *testing.T) {
	direct := &AntigravityAccountSwitchError{
		OriginalAccountID: 123,
		RateLimitedModel:  "claude-sonnet-4-5",
		IsStickySession:   true,
	}
	wrapped := fmt.Errorf("wrapped: %w", &AntigravityAccountSwitchError{
		OriginalAccountID: 456,
		RateLimitedModel:  "gemini-3-flash",
		IsStickySession:   false,
	})

	cases := []struct {
		name          string
		err           error
		expectedOK    bool
		expectedID    int64
		expectedModel string
	}{
		{
			name:       "nil error",
			err:        nil,
			expectedOK: false,
		},
		{
			name:       "generic error",
			err:        fmt.Errorf("some error"),
			expectedOK: false,
		},
		{
			name:          "account switch error",
			err:           direct,
			expectedOK:    true,
			expectedID:    123,
			expectedModel: "claude-sonnet-4-5",
		},
		{
			name:          "wrapped account switch error",
			err:           wrapped,
			expectedOK:    true,
			expectedID:    456,
			expectedModel: "gemini-3-flash",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, matched := IsAntigravityAccountSwitchError(tc.err)
			require.Equal(t, tc.expectedOK, matched)

			if !tc.expectedOK {
				require.Nil(t, got)
				return
			}

			require.NotNil(t, got)
			require.Equal(t, tc.expectedID, got.OriginalAccountID)
			require.Equal(t, tc.expectedModel, got.RateLimitedModel)
		})
	}
}
// TestAntigravityAccountSwitchError_Error checks that the error message
// contains both the original account ID and the rate-limited model name.
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
	switchErr := &AntigravityAccountSwitchError{
		OriginalAccountID: 789,
		RateLimitedModel:  "claude-opus-4-5",
		IsStickySession:   true,
	}

	text := switchErr.Error()

	for _, fragment := range []string{"789", "claude-opus-4-5"} {
		require.Contains(t, text, fragment)
	}
}
// stubSchedulerCache is a SchedulerCache implementation used in tests.
// It embeds SchedulerCache and overrides SetAccount to record its calls.
type stubSchedulerCache struct {
SchedulerCache
setAccountCalls []*Account
setAccountErr error
}
// SetAccount appends account to setAccountCalls and returns setAccountErr.
// The context is not used.
func (s *stubSchedulerCache) SetAccount(ctx context.Context, account *Account) error {
s.setAccountCalls = append(s.setAccountCalls, account)
return s.setAccountErr
}
// TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache checks
// that a model rate limit is written into account.Extra and that the
// scheduler cache receives the account.
func TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache(t *testing.T) {
	const modelKey = "claude-sonnet-4-5"

	schedCache := &stubSchedulerCache{}
	gateway := &AntigravityGatewayService{
		schedulerSnapshot: &SchedulerSnapshotService{cache: schedCache},
	}
	acc := &Account{
		ID:       100,
		Name:     "test-account",
		Platform: PlatformAntigravity,
	}
	resetAt := time.Now().Add(30 * time.Second)

	gateway.updateAccountModelRateLimitInCache(context.Background(), acc, modelKey, resetAt)

	// Extra must now hold the rate-limit entry for the model.
	require.NotNil(t, acc.Extra)
	allLimits, ok := acc.Extra["model_rate_limits"].(map[string]any)
	require.True(t, ok)
	entry, ok := allLimits[modelKey].(map[string]any)
	require.True(t, ok)
	require.NotEmpty(t, entry["rate_limited_at"])
	require.NotEmpty(t, entry["rate_limit_reset_at"])

	// cache.SetAccount must have been called once with this account.
	require.Len(t, schedCache.setAccountCalls, 1)
	require.Equal(t, acc.ID, schedCache.setAccountCalls[0].ID)
}
// TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot checks that a
// nil schedulerSnapshot does not panic and leaves account.Extra untouched.
func TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot(t *testing.T) {
	gateway := &AntigravityGatewayService{schedulerSnapshot: nil}
	acc := &Account{ID: 1, Name: "test"}
	resetAt := time.Now().Add(30 * time.Second)

	// Must not panic.
	gateway.updateAccountModelRateLimitInCache(context.Background(), acc, "claude-sonnet-4-5", resetAt)

	// Extra must not be updated (the function is expected to return early).
	require.Nil(t, acc.Extra)
}
// TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra checks that
// adding a model rate limit keeps the data already present in account.Extra.
func TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra(t *testing.T) {
	schedCache := &stubSchedulerCache{}
	gateway := &AntigravityGatewayService{
		schedulerSnapshot: &SchedulerSnapshotService{cache: schedCache},
	}

	existingLimits := map[string]any{
		"gemini-3-flash": map[string]any{
			"rate_limited_at":     "2024-01-01T00:00:00Z",
			"rate_limit_reset_at": "2024-01-01T00:05:00Z",
		},
	}
	acc := &Account{
		ID:       200,
		Name:     "test-account",
		Platform: PlatformAntigravity,
		Extra: map[string]any{
			"existing_key":      "existing_value",
			"model_rate_limits": existingLimits,
		},
	}
	resetAt := time.Now().Add(30 * time.Second)

	gateway.updateAccountModelRateLimitInCache(context.Background(), acc, "claude-sonnet-4-5", resetAt)

	// Pre-existing data must still be there, next to the new entry.
	require.Equal(t, "existing_value", acc.Extra["existing_key"])
	allLimits := acc.Extra["model_rate_limits"].(map[string]any)
	require.NotNil(t, allLimits["gemini-3-flash"])
	require.NotNil(t, allLimits["claude-sonnet-4-5"])
}
// TestSchedulerSnapshotService_UpdateAccountInCache exercises
// UpdateAccountInCache: the normal path, a nil cache, a nil account, and
// propagation of the cache error.
func TestSchedulerSnapshotService_UpdateAccountInCache(t *testing.T) {
	ctx := context.Background()

	t.Run("calls cache.SetAccount", func(t *testing.T) {
		schedCache := &stubSchedulerCache{}
		snapshot := &SchedulerSnapshotService{cache: schedCache}

		err := snapshot.UpdateAccountInCache(ctx, &Account{ID: 123, Name: "test"})

		require.NoError(t, err)
		require.Len(t, schedCache.setAccountCalls, 1)
		require.Equal(t, int64(123), schedCache.setAccountCalls[0].ID)
	})

	t.Run("returns nil when cache is nil", func(t *testing.T) {
		snapshot := &SchedulerSnapshotService{cache: nil}

		err := snapshot.UpdateAccountInCache(ctx, &Account{ID: 1})

		require.NoError(t, err)
	})

	t.Run("returns nil when account is nil", func(t *testing.T) {
		schedCache := &stubSchedulerCache{}
		snapshot := &SchedulerSnapshotService{cache: schedCache}

		err := snapshot.UpdateAccountInCache(ctx, nil)

		require.NoError(t, err)
		require.Empty(t, schedCache.setAccountCalls)
	})

	t.Run("propagates cache error", func(t *testing.T) {
		wantErr := fmt.Errorf("cache error")
		snapshot := &SchedulerSnapshotService{
			cache: &stubSchedulerCache{setAccountErr: wantErr},
		}

		err := snapshot.UpdateAccountInCache(ctx, &Account{ID: 1})

		require.ErrorIs(t, err, wantErr)
	})
}
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// mockSmartRetryUpstream is a scripted mock upstream for handleSmartRetry
// tests. Each call to Do consumes the next scripted entry and records the
// requested URL.
type mockSmartRetryUpstream struct {
	responses []*http.Response // scripted responses, consumed in call order
	errors    []error          // scripted errors, paired with responses by index; may be shorter or nil
	callIdx   int              // index of the next scripted entry
	calls     []string         // URL of every request received, in call order
}

// Do records req's URL and returns the scripted response and error for the
// current call index, then advances the index.
//
// Once the scripted responses are exhausted it returns (nil, nil). A missing
// entry in errors is treated as "no error", so a test may script responses
// alone without supplying a parallel errors slice; previously that case
// panicked with an index-out-of-range on m.errors[idx].
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	idx := m.callIdx
	m.calls = append(m.calls, req.URL.String())
	m.callIdx++

	if idx >= len(m.responses) {
		return nil, nil
	}

	var err error
	if idx < len(m.errors) {
		err = m.errors[idx]
	}
	return m.responses[idx], err
}

// DoWithTLS ignores enableTLSFingerprint and delegates to Do.
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return m.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestHandleSmartRetry_URLLevelRateLimit covers the URL-level rate-limit switch:
// a 429 whose message is "Resource has been exhausted", with a second base URL
// available, is expected to produce smartRetryActionContinueURL.
func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
	acct := &Account{ID: 1, Name: "acc-1", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	payload := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acct,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	baseURLs := []string{"https://ag-1.test", "https://ag-2.test"}
	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, baseURLs)

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionContinueURL, outcome.action)
	require.Nil(t, outcome.resp)
	require.Nil(t, outcome.err)
	require.Nil(t, outcome.switchError)
}
// TestHandleSmartRetry_LongDelay_ReturnsSwitchError checks that a retryDelay at
// or above the threshold yields a switchError and records a model rate limit.
func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{ID: 1, Name: "acc-1", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 15s >= 7s threshold, so a switchError is expected.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		accountRepo:     rateLimitRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.Nil(t, outcome.resp, "should not return resp when switchError is set")
	require.Nil(t, outcome.err)
	require.NotNil(t, outcome.switchError, "should return switchError for long delay")
	require.Equal(t, acct.ID, outcome.switchError.OriginalAccountID)
	require.Equal(t, "claude-sonnet-4-5", outcome.switchError.RateLimitedModel)
	require.True(t, outcome.switchError.IsStickySession)

	// The model rate limit must have been recorded on the repository.
	require.Len(t, rateLimitRepo.modelRateLimitCalls, 1)
	require.Equal(t, "claude-sonnet-4-5", rateLimitRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_SmartRetrySuccess checks the smart-retry happy
// path: a short retryDelay triggers one retry, and the successful upstream
// response is handed back without a switchError.
func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
	okResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
	}
	upstream := &mockSmartRetryUpstream{
		responses: []*http.Response{okResp},
		errors:    []error{nil},
	}
	acct := &Account{ID: 1, Name: "acc-1", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 0.5s < 7s threshold, so the smart retry should kick in.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:          context.Background(),
		prefix:       "[test]",
		account:      acct,
		accessToken:  "token",
		action:       "generateContent",
		body:         []byte(`{"input":"test"}`),
		httpUpstream: upstream,
		handleError:  noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.NotNil(t, outcome.resp, "should return successful response")
	require.Equal(t, http.StatusOK, outcome.resp.StatusCode)
	require.Nil(t, outcome.err)
	require.Nil(t, outcome.switchError, "should not return switchError on success")
	require.Len(t, upstream.calls, 1, "should have made one retry call")
}
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError checks that
// when every smart retry still gets a 429, handleSmartRetry gives up with a
// switchError and records a model rate limit.
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
	// Every retry still returns 429. Three responses are scripted because the
	// smart retry makes at most three attempts.
	failBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
	// Each scripted response needs its own body reader.
	newFailResp := func() *http.Response {
		return &http.Response{
			StatusCode: http.StatusTooManyRequests,
			Header:     http.Header{},
			Body:       io.NopCloser(strings.NewReader(failBody)),
		}
	}
	upstream := &mockSmartRetryUpstream{
		responses: []*http.Response{newFailResp(), newFailResp(), newFailResp()},
		errors:    []error{nil, nil, nil},
	}
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{ID: 2, Name: "acc-2", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 0.1s < 7s threshold, so the smart retry should kick in (at most 3 attempts).
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		httpUpstream:    upstream,
		accountRepo:     rateLimitRepo,
		isStickySession: false,
		handleError:     noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.Nil(t, outcome.resp, "should not return resp when switchError is set")
	require.Nil(t, outcome.err)
	require.NotNil(t, outcome.switchError, "should return switchError after smart retry failed")
	require.Equal(t, acct.ID, outcome.switchError.OriginalAccountID)
	require.Equal(t, "gemini-3-flash", outcome.switchError.RateLimitedModel)
	require.False(t, outcome.switchError.IsStickySession)

	// The model rate limit must have been recorded on the repository.
	require.Len(t, rateLimitRepo.modelRateLimitCalls, 1)
	require.Equal(t, "gemini-3-flash", rateLimitRepo.modelRateLimitCalls[0].modelKey)
	require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError checks that a
// 503 with reason MODEL_CAPACITY_EXHAUSTED and a long retryDelay yields a
// switchError and records a model rate limit.
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{ID: 3, Name: "acc-3", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s threshold.
	payload := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
	unavailable := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		accountRepo:     rateLimitRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, unavailable, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.Nil(t, outcome.resp)
	require.Nil(t, outcome.err)
	require.NotNil(t, outcome.switchError, "should return switchError for 503 model capacity exhausted")
	require.Equal(t, acct.ID, outcome.switchError.OriginalAccountID)
	require.Equal(t, "gemini-3-pro-high", outcome.switchError.RateLimitedModel)
	require.True(t, outcome.switchError.IsStickySession)

	// The model rate limit must have been recorded on the repository.
	require.Len(t, rateLimitRepo.modelRateLimitCalls, 1)
	require.Equal(t, "gemini-3-pro-high", rateLimitRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic checks that an
// account on another platform gets smartRetryActionContinue, i.e. falls through
// to the default logic.
func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing.T) {
	// Not an Antigravity platform account (API-key account on PlatformAnthropic).
	acct := &Account{ID: 4, Name: "acc-4", Type: AccountTypeAPIKey, Platform: PlatformAnthropic}

	// Even though this is a model rate-limit response, the non-Antigravity
	// account is expected to go through the default logic.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acct,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionContinue, outcome.action, "non-Antigravity platform account should continue default logic")
	require.Nil(t, outcome.resp)
	require.Nil(t, outcome.err)
	require.Nil(t, outcome.switchError)
}
// TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic checks that a 429
// without a model rate-limit reason gets smartRetryActionContinue, i.e. falls
// through to the default logic.
func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) {
	acct := &Account{ID: 5, Name: "acc-5", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// A 429 that carries neither RATE_LIMIT_EXCEEDED nor MODEL_CAPACITY_EXHAUSTED.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
],
"message": "Quota exceeded"
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acct,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionContinue, outcome.action, "non-model rate limit should continue default logic")
	require.Nil(t, outcome.resp)
	require.Nil(t, outcome.err)
	require.Nil(t, outcome.switchError)
}
// TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError checks the boundary
// case: a retryDelay exactly equal to the threshold yields a switchError.
func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{ID: 6, Name: "acc-6", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// Exactly 7s == 7s threshold, so a switchError is expected.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acct,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		accountRepo: rateLimitRepo,
		handleError: noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.Nil(t, outcome.resp)
	require.NotNil(t, outcome.switchError, "exactly at threshold should return switchError")
	require.Equal(t, "gemini-pro", outcome.switchError.RateLimitedModel)
}
// TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates checks that a
// switchError surfaces from antigravityRetryLoop as an
// *AntigravityAccountSwitchError with no result.
func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing.T) {
	// Simulate a 429 response carrying a long retry delay.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)
	upstream := &mockSmartRetryUpstream{
		responses: []*http.Response{{
			StatusCode: http.StatusTooManyRequests,
			Header:     http.Header{},
			Body:       io.NopCloser(bytes.NewReader(payload)),
		}},
		errors: []error{nil},
	}
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{
		ID:          7,
		Name:        "acc-7",
		Type:        AccountTypeOAuth,
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		httpUpstream:    upstream,
		accountRepo:     rateLimitRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	loopResult, loopErr := gateway.antigravityRetryLoop(loopParams)

	require.Nil(t, loopResult, "should not return result when switchError")
	require.NotNil(t, loopErr, "should return error")

	var switchErr *AntigravityAccountSwitchError
	require.ErrorAs(t, loopErr, &switchErr, "error should be AntigravityAccountSwitchError")
	require.Equal(t, acct.ID, switchErr.OriginalAccountID)
	require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
	require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry checks that the smart retry
// keeps going after a failed attempt: the first upstream call yields a nil
// response, the second succeeds, and the success is returned.
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
	// First attempt simulates a network error, second attempt succeeds.
	okResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
	}
	upstream := &mockSmartRetryUpstream{
		// The first scripted response is nil to simulate the network error.
		responses: []*http.Response{nil, okResp},
		// The mock returns no error; the nil response alone triggers the retry.
		errors: []error{nil, nil},
	}
	acct := &Account{ID: 8, Name: "acc-8", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 0.1s < 7s threshold, so the smart retry should kick in.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:          context.Background(),
		prefix:       "[test]",
		account:      acct,
		accessToken:  "token",
		action:       "generateContent",
		body:         []byte(`{"input":"test"}`),
		httpUpstream: upstream,
		handleError:  noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.NotNil(t, outcome.resp, "should return successful response after network error recovery")
	require.Equal(t, http.StatusOK, outcome.resp.StatusCode)
	require.Nil(t, outcome.switchError, "should not return switchError on success")
	require.Len(t, upstream.calls, 2, "should have made two retry calls")
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit covers a model rate
// limit without a retryDelay, for which the default 1-minute rate limit applies;
// it asserts the resulting switchError and the recorded model rate limit.
func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
	rateLimitRepo := &stubAntigravityAccountRepo{}
	acct := &Account{ID: 9, Name: "acc-9", Type: AccountTypeOAuth, Platform: PlatformAntigravity}

	// 429 + RATE_LIMIT_EXCEEDED + no retryDelay -> default 1-minute rate limit.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`)
	rateLimited := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		accountRepo:     rateLimitRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	gateway := &AntigravityGatewayService{}
	outcome := gateway.handleSmartRetry(loopParams, rateLimited, payload, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, outcome)
	require.Equal(t, smartRetryActionBreakWithResp, outcome.action)
	require.Nil(t, outcome.resp, "should not return resp when switchError is set")
	require.NotNil(t, outcome.switchError, "should return switchError for no retryDelay")
	require.Equal(t, "claude-sonnet-4-5", outcome.switchError.RateLimitedModel)
	require.True(t, outcome.switchError.IsStickySession)

	// The model rate limit must have been recorded on the repository.
	require.Len(t, rateLimitRepo.modelRateLimitCalls, 1)
	require.Equal(t, "claude-sonnet-4-5", rateLimitRepo.modelRateLimitCalls[0].modelKey)
}
//go:build unit
package service
import (
"testing"
)
// TestApplyThinkingModelSuffix is a table-driven check of
// applyThinkingModelSuffix: among the cases below, only claude-sonnet-4-5 with
// thinking enabled is expected to gain the "-thinking" suffix.
func TestApplyThinkingModelSuffix(t *testing.T) {
	type suffixCase struct {
		name            string
		mappedModel     string
		thinkingEnabled bool
		expected        string
	}
	cases := []suffixCase{
		// Thinking disabled: the model name is left untouched.
		{name: "thinking disabled - claude-sonnet-4-5 unchanged", mappedModel: "claude-sonnet-4-5", thinkingEnabled: false, expected: "claude-sonnet-4-5"},
		{name: "thinking disabled - other model unchanged", mappedModel: "claude-opus-4-6-thinking", thinkingEnabled: false, expected: "claude-opus-4-6-thinking"},
		// Thinking enabled + claude-sonnet-4-5: the suffix is appended automatically.
		{name: "thinking enabled - claude-sonnet-4-5 becomes thinking version", mappedModel: "claude-sonnet-4-5", thinkingEnabled: true, expected: "claude-sonnet-4-5-thinking"},
		// Thinking enabled + any other model: the model name is left untouched.
		{name: "thinking enabled - claude-sonnet-4-5-thinking unchanged", mappedModel: "claude-sonnet-4-5-thinking", thinkingEnabled: true, expected: "claude-sonnet-4-5-thinking"},
		{name: "thinking enabled - claude-opus-4-6-thinking unchanged", mappedModel: "claude-opus-4-6-thinking", thinkingEnabled: true, expected: "claude-opus-4-6-thinking"},
		{name: "thinking enabled - gemini model unchanged", mappedModel: "gemini-3-flash", thinkingEnabled: true, expected: "gemini-3-flash"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := applyThinkingModelSuffix(tc.mappedModel, tc.thinkingEnabled)
			if got == tc.expected {
				return
			}
			t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
				tc.mappedModel, tc.thinkingEnabled, got, tc.expected)
		})
	}
}
......@@ -42,7 +42,18 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", errors.New("upstream account missing api_key in credentials")
}
return apiKey, nil
}
if account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// TestAntigravityTokenProvider_GetAccessToken_Upstream covers upstream-type
// accounts: the api_key credential is returned as the token, and a missing,
// empty, or absent credentials map produces an error.
func TestAntigravityTokenProvider_GetAccessToken_Upstream(t *testing.T) {
	provider := &AntigravityTokenProvider{}

	cases := []struct {
		name        string
		credentials map[string]any // nil leaves Account.Credentials unset
		wantToken   string
		wantErr     string // empty means success is expected
	}{
		{
			name:        "upstream account with valid api_key",
			credentials: map[string]any{"api_key": "sk-test-key-12345"},
			wantToken:   "sk-test-key-12345",
		},
		{
			name:        "upstream account missing api_key",
			credentials: map[string]any{},
			wantErr:     "upstream account missing api_key",
		},
		{
			name:        "upstream account with empty api_key",
			credentials: map[string]any{"api_key": ""},
			wantErr:     "upstream account missing api_key",
		},
		{
			name:    "upstream account with nil credentials",
			wantErr: "upstream account missing api_key",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			acct := &Account{Platform: PlatformAntigravity, Type: AccountTypeUpstream}
			if tc.credentials != nil {
				acct.Credentials = tc.credentials
			}

			token, err := provider.GetAccessToken(context.Background(), acct)

			if tc.wantErr == "" {
				require.NoError(t, err)
				require.Equal(t, tc.wantToken, token)
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
			require.Empty(t, token)
		})
	}
}
// TestAntigravityTokenProvider_GetAccessToken_Guards covers the input guards of
// GetAccessToken: a nil account, a non-Antigravity platform, and an unsupported
// account type must each fail with the expected message and an empty token.
func TestAntigravityTokenProvider_GetAccessToken_Guards(t *testing.T) {
	provider := &AntigravityTokenProvider{}

	cases := []struct {
		name    string
		account *Account
		wantErr string
	}{
		{
			name:    "nil account",
			account: nil,
			wantErr: "account is nil",
		},
		{
			name:    "non-antigravity platform",
			account: &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth},
			wantErr: "not an antigravity account",
		},
		{
			name:    "unsupported account type",
			account: &Account{Platform: PlatformAntigravity, Type: AccountTypeAPIKey},
			wantErr: "not an antigravity oauth account",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			token, err := provider.GetAccessToken(context.Background(), tc.account)

			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
			require.Empty(t, token)
		})
	}
}
......@@ -6,8 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
......@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
......@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
randVal := rand.Float64()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
......
......@@ -56,7 +56,8 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
//
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
// Step 3: 对于 messages 路径,进行严格验证:
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
// Step 4: 对于 messages 路径,进行严格验证:
// - System prompt 相似度检查
// - X-App header 检查
// - anthropic-beta header 检查
......@@ -75,14 +76,20 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return true
}
// Step 3: messages 路径,进行严格验证
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
}
// Step 4: messages 路径,进行严格验证
// 3.1 检查 system prompt 相似度
// 4.1 检查 system prompt 相似度
if !v.hasClaudeCodeSystemPrompt(body) {
return false
}
// 3.2 检查必需的 headers(值不为空即可)
// 4.2 检查必需的 headers(值不为空即可)
xApp := r.Header.Get("X-App")
if xApp == "" {
return false
......@@ -98,7 +105,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false
}
// 3.3 验证 metadata.user_id
// 4.3 验证 metadata.user_id
if body == nil {
return false
}
......
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestClaudeCodeValidator_ProbeBypass(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
ok := validator.Validate(req, map[string]any{
"model": "claude-haiku-4-5",
"max_tokens": 1,
})
require.True(t, ok)
}
func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
req.Header.Set("User-Agent", "curl/8.0.0")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
ok := validator.Validate(req, map[string]any{
"model": "claude-haiku-4-5",
"max_tokens": 1,
})
require.False(t, ok)
}
// TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation expects
// Validate to reject a /v1/messages request that has a claude-cli User-Agent
// but no probe flag in its context.
func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
	req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")

	payload := map[string]any{
		"model":      "claude-haiku-4-5",
		"max_tokens": 1,
	}

	accepted := NewClaudeCodeValidator().Validate(req, payload)
	require.False(t, accepted)
}
// TestClaudeCodeValidator_NonMessagesPathUAOnly expects Validate to accept a
// /v1/models request with a claude-cli User-Agent and a nil body.
func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil)
	req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")

	accepted := NewClaudeCodeValidator().Validate(req, nil)
	require.True(t, accepted)
}
......@@ -35,6 +35,7 @@ type ConcurrencyCache interface {
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
......@@ -77,6 +78,11 @@ type AccountWithConcurrency struct {
MaxConcurrency int
}
// UserWithConcurrency pairs a user ID with that user's MaxConcurrency value.
// It is the per-user input element of GetUsersLoadBatch.
type UserWithConcurrency struct {
	ID int64
	MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
......@@ -84,6 +90,13 @@ type AccountLoadInfo struct {
LoadRate int // 0-100+ (percent)
}
// UserLoadInfo is the per-user load entry returned by GetUsersLoadBatch.
type UserLoadInfo struct {
	UserID int64
	CurrentConcurrency int
	WaitingCount int
	LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
......@@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// GetUsersLoadBatch fetches load information for a batch of users from the
// concurrency cache. When the service has no cache configured it returns an
// empty, non-nil map and a nil error.
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
	if s.cache != nil {
		return s.cache.GetUsersLoadBatch(ctx, users)
	}
	return make(map[int64]*UserLoadInfo), nil
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
......
......@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
// GetBatchUserUsageStats fetches usage statistics for the given users from the
// usage repository, forwarding startTime and endTime unchanged. A repository
// error is wrapped with context and returned with a nil map.
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
	perUser, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
	if err == nil {
		return perUser, nil
	}
	return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
......
package service
import "github.com/gin-gonic/gin"
// errorPassthroughServiceContextKey is the gin context key under which
// BindErrorPassthroughService stores the *ErrorPassthroughService and from
// which getBoundErrorPassthroughService reads it back.
const errorPassthroughServiceContextKey = "error_passthrough_service"
// BindErrorPassthroughService attaches the error passthrough service to the
// request context so the service layer can reuse its rules in non-failover
// scenarios. It is a no-op when either c or svc is nil.
func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) {
	if c != nil && svc != nil {
		c.Set(errorPassthroughServiceContextKey, svc)
	}
}
// getBoundErrorPassthroughService returns the *ErrorPassthroughService stored
// in the gin context under errorPassthroughServiceContextKey. It returns nil
// when c is nil, when nothing is stored under the key, or when the stored
// value has a different type.
func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService {
	if c == nil {
		return nil
	}
	if bound, exists := c.Get(errorPassthroughServiceContextKey); exists {
		if svc, isService := bound.(*ErrorPassthroughService); isService {
			return svc
		}
	}
	return nil
}
// applyErrorPassthroughRule rewrites an upstream error response according to
// the passthrough rule that matches it, if any.
//
// When no service is bound to c, or no rule matches, the supplied defaults are
// returned unchanged and matched is false. When a rule matches:
//   - status is upstreamStatus, unless the rule disables code passthrough and
//     provides a ResponseCode, in which case that code is used;
//   - errMsg is the message extracted from responseBody, unless the rule
//     disables body passthrough and provides a CustomMessage;
//   - errType is always "upstream_error".
func applyErrorPassthroughRule(
	c *gin.Context,
	platform string,
	upstreamStatus int,
	responseBody []byte,
	defaultStatus int,
	defaultErrType string,
	defaultErrMsg string,
) (status int, errType string, errMsg string, matched bool) {
	svc := getBoundErrorPassthroughService(c)
	if svc == nil {
		return defaultStatus, defaultErrType, defaultErrMsg, false
	}

	rule := svc.MatchRule(platform, upstreamStatus, responseBody)
	if rule == nil {
		return defaultStatus, defaultErrType, defaultErrMsg, false
	}

	if rule.PassthroughCode || rule.ResponseCode == nil {
		status = upstreamStatus
	} else {
		status = *rule.ResponseCode
	}

	errMsg = ExtractUpstreamErrorMessage(responseBody)
	if useCustomMessage := !rule.PassthroughBody && rule.CustomMessage != nil; useCustomMessage {
		errMsg = *rule.CustomMessage
	}

	// Stay consistent with the existing failover path: a matched rule always
	// reports the error type as upstream_error.
	return status, "upstream_error", errMsg, true
}
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestApplyErrorPassthroughRule_NoBoundService checks that, when no passthrough
// service has been bound to the gin context, applyErrorPassthroughRule returns
// the caller-supplied defaults untouched and reports matched == false.
func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())

	upstreamBody := []byte(`{"error":{"message":"invalid schema"}}`)

	gotStatus, gotType, gotMsg, gotMatched := applyErrorPassthroughRule(
		ginCtx,
		PlatformAnthropic,
		http.StatusUnprocessableEntity,
		upstreamBody,
		http.StatusBadGateway,
		"upstream_error",
		"Upstream request failed",
	)

	assert.False(t, gotMatched)
	assert.Equal(t, http.StatusBadGateway, gotStatus)
	assert.Equal(t, "upstream_error", gotType)
	assert.Equal(t, "Upstream request failed", gotMsg)
}
// TestGatewayHandleErrorResponse_NoRuleKeepsDefault checks that, with no
// passthrough service bound to the context, GatewayService.handleErrorResponse
// answers an upstream 422 with 502 and the generic upstream_error payload.
func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstream := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}

	gateway := &GatewayService{}
	_, err := gateway.handleErrorResponse(context.Background(), upstream, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadGateway, recorder.Code)

	// Decode what was written to the client and inspect its "error" object.
	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestOpenAIHandleErrorResponse_NoRuleKeepsDefault checks that, with no
// passthrough service bound to the context,
// OpenAIGatewayService.handleErrorResponse answers an upstream 422 with 502 and
// the generic upstream_error payload.
func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstream := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	gateway := &OpenAIGatewayService{}
	_, err := gateway.handleErrorResponse(context.Background(), upstream, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadGateway, recorder.Code)

	// Decode what was written to the client and inspect its "error" object.
	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault checks that, with no
// passthrough service bound to the context,
// GeminiMessagesCompatService.writeGeminiMappedError answers an upstream 422
// with 400, type invalid_request_error and the generic failure message.
func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
	acct := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey}

	compat := &GeminiMessagesCompatService{}
	err := compat.writeGeminiMappedError(ginCtx, acct, http.StatusUnprocessableEntity, "req-2", upstreamBody)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadRequest, recorder.Code)

	// Decode what was written to the client and inspect its "error" object.
	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "invalid_request_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestGatewayHandleErrorResponse_AppliesRuleFor422 binds a passthrough service
// holding one rule (status 422, keyword "invalid schema", response code 418,
// custom message) and checks that GatewayService.handleErrorResponse answers
// with the rule's status code and custom message under type upstream_error.
func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Install the rule and make it visible to the service layer via the context.
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")
	passthrough := &ErrorPassthroughService{}
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstream := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}

	gateway := &GatewayService{}
	_, err := gateway.handleErrorResponse(context.Background(), upstream, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "上游请求失败", errObj["message"])
}
// TestOpenAIHandleErrorResponse_AppliesRuleFor422 binds a passthrough service
// holding one rule (status 422, keyword "invalid schema", response code 418,
// custom message) and checks that OpenAIGatewayService.handleErrorResponse
// answers with the rule's status code and custom message under type
// upstream_error.
func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Install the rule and make it visible to the service layer via the context.
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")
	passthrough := &ErrorPassthroughService{}
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstream := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	gateway := &OpenAIGatewayService{}
	_, err := gateway.handleErrorResponse(context.Background(), upstream, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "OpenAI上游失败", errObj["message"])
}
// TestGeminiWriteGeminiMappedError_AppliesRuleFor422 binds a passthrough
// service holding one rule (status 422, keyword "invalid schema", response code
// 418, custom message) and checks that
// GeminiMessagesCompatService.writeGeminiMappedError answers with the rule's
// status code and custom message under type upstream_error.
func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Install the rule and make it visible to the service layer via the context.
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")
	passthrough := &ErrorPassthroughService{}
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
	acct := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey}

	compat := &GeminiMessagesCompatService{}
	err := compat.writeGeminiMappedError(ginCtx, acct, http.StatusUnprocessableEntity, "req-1", upstreamBody)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Gemini上游失败", errObj["message"])
}
// newNonFailoverPassthroughRule builds an enabled test rule (ID 1, priority 1,
// MatchModeAll) that targets the single given status code and keyword. Both
// code and body passthrough are off, so the rule carries respCode as its
// ResponseCode and customMessage as its CustomMessage.
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
	rule := new(model.ErrorPassthroughRule)

	// Identity and activation.
	rule.ID = 1
	rule.Name = "non-failover-rule"
	rule.Enabled = true
	rule.Priority = 1

	// Match criteria.
	rule.ErrorCodes = []int{statusCode}
	rule.Keywords = []string{keyword}
	rule.MatchMode = model.MatchModeAll

	// Response rewriting: override both the status code and the message.
	rule.PassthroughCode = false
	rule.ResponseCode = &respCode
	rule.PassthroughBody = false
	rule.CustomMessage = &customMessage

	return rule
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment