Unverified Commit c4615a12 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #509 from touwaeriol/pr/antigravity-full

feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops
parents 5d4327eb fa28dcbf
......@@ -8,53 +8,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestIsAntigravityModelSupported(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// 直接支持的模型
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
// 可映射的模型
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// Claude 前缀兜底
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
{"Claude前缀 - claude-future-version", "claude-future-version", true},
// 不支持的模型
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false},
{"不支持 - mistral-7b", "mistral-7b", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsAntigravityModelSupported(tt.model)
require.Equal(t, tt.expected, got, "model: %s", tt.model)
})
}
}
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
......@@ -64,7 +17,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
// 1. 账户级映射优先
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
......@@ -72,120 +25,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
name: "账户映射 - 可覆盖默认映射的模型",
requestedModel: "claude-sonnet-4-5",
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
expected: "my-custom-sonnet",
},
{
name: "账户映射 - 可覆盖未知模型",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
// 2. 默认映射(DefaultAntigravityModelMapping)
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
// 3. 默认映射中的透传(映射到自己)
{
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307",
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
name: "默认映射透传 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 2.5 → 3 映射
{
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
name: "默认映射透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "gemini-2.5-flash",
},
{
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
name: "默认映射透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-3-pro-high",
expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
name: "默认映射透传 - gemini-3-flash",
requestedModel: "gemini-3-flash",
accountMapping: nil,
expected: "gemini-future-model",
expected: "gemini-3-flash",
},
// 4. 直接支持的模型
// 4. 未在默认映射中的模型返回空字符串(不支持)
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
name: "未知模型 - claude-unknown 返回空",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
name: "未知模型 - claude-3-opus-20240229 返回空",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "",
},
// 5. 默认值 fallback(未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
name: "未知模型 - claude-opus-4 返回空",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
name: "未知模型 - gemini-future-model 返回空",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
}
......@@ -219,12 +176,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
// 空字符串和非 claude/gemini 前缀返回空字符串
{"空字符串", "", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
{"非claude/gemini前缀 - llama", "llama-3", ""},
}
for _, tt := range tests {
......@@ -248,10 +203,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 可映射(有明确前缀映射)
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传
// 前缀透传(claude 和 gemini 前缀)
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
......@@ -267,3 +222,58 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
})
}
}
// TestMapAntigravityModel_WildcardTargetEqualsRequest 测试通配符映射目标恰好等于请求模型名的 edge case
// 例如 {"claude-*": "claude-sonnet-4-5"},请求 "claude-sonnet-4-5" 时应该通过
func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
tests := []struct {
name string
modelMapping map[string]any
requestedModel string
expected string
}{
{
name: "wildcard target equals request model",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
{
name: "wildcard target differs from request model",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "claude-opus-4-6",
expected: "claude-sonnet-4-5",
},
{
name: "wildcard no match",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "gpt-4o",
expected: "",
},
{
name: "explicit passthrough same name",
modelMapping: map[string]any{"claude-sonnet-4-5": "claude-sonnet-4-5"},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
{
name: "multiple wildcards target equals one request",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5", "gemini-*": "gemini-2.5-flash"},
requestedModel: "gemini-2.5-flash",
expected: "gemini-2.5-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": tt.modelMapping,
},
}
got := mapAntigravityModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "mapAntigravityModel(%q) = %q, want %q", tt.requestedModel, got, tt.expected)
})
}
}
package service
import (
"context"
"slices"
"strings"
"time"
......@@ -57,15 +58,20 @@ func normalizeAntigravityModelName(model string) string {
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
}
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.isModelRateLimited(requestedModel) {
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
......@@ -132,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
}
......@@ -21,6 +21,23 @@ type stubAntigravityUpstream struct {
calls []string
}
type recordingOKUpstream struct {
calls int
}
func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
r.calls++
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
}
func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return r.Do(req, proxyURL, accountID, accountConcurrency)
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String()
s.calls = append(s.calls, url)
......@@ -53,10 +70,17 @@ type rateLimitCall struct {
resetAt time.Time
}
type modelRateLimitCall struct {
accountID int64
modelKey string // 存储的 key(应该是官方模型 ID,如 "claude-sonnet-4-5")
resetAt time.Time
}
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
......@@ -69,6 +93,11 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
return nil
}
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error {
s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt})
return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
......@@ -93,18 +122,21 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
}
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
})
......@@ -123,14 +155,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true")
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
......@@ -140,20 +172,122 @@ func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "false")
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 1, Name: "acc-1", Platform: PlatformAntigravity}
// 429 + RATE_LIMIT_EXCEEDED + 模型名 → 模型限流
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
require.True(t, result.Handled)
require.NotNil(t, result.SwitchError)
require.Equal(t, "claude-sonnet-4-5", result.SwitchError.RateLimitedModel)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流)
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
// 不应该触发模型限流,应该走 scope 限流
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Len(t, repo.scopeCalls, 1)
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
body := []byte(`{
"error": {
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
require.True(t, result.Handled)
require.NotNil(t, result.SwitchError)
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 4, Name: "acc-4", Platform: PlatformAntigravity}
// 503 + 普通错误(非 MODEL_CAPACITY_EXHAUSTED)→ 不做任何处理
body := []byte(`{
"error": {
"status": "UNAVAILABLE",
"message": "Service temporarily unavailable",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "SERVICE_UNAVAILABLE"}
]
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
// 503 非模型限流不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
// TestHandleUpstreamError_503_EmptyBody 测试 503 空响应体(不处理)
func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
account := &Account{ID: 5, Name: "acc-5", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("2s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
// 503 + 空响应体 → 不做任何处理
body := []byte(`{}`)
require.Len(t, repo.rateCalls, 1)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls)
call := repo.rateCalls[0]
require.Equal(t, account.ID, call.accountID)
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
require.Empty(t, repo.rateCalls)
}
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
......@@ -188,3 +322,771 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
func buildGeminiRateLimitBody(delay string) []byte {
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
}
func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
// Avoid flakiness around Unix second boundaries.
for {
now := time.Now()
if now.Nanosecond() < 800*1e6 {
break
}
time.Sleep(5 * time.Millisecond)
}
baseUnix := time.Now().Unix()
ts := ParseGeminiRateLimitResetTime(buildGeminiRateLimitBody("0.1s"))
require.NotNil(t, ts)
require.Equal(t, baseUnix+1, *ts, "fractional seconds should be rounded up to the next second")
}
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
tests := []struct {
name string
body string
expectedDelay time.Duration
expectedModel string
expectedNil bool
}{
{
name: "valid complete response with RATE_LIMIT_EXCEEDED",
body: `{
"error": {
"code": 429,
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"domain": "cloudcode-pa.googleapis.com",
"metadata": {
"model": "claude-sonnet-4-5",
"quotaResetDelay": "201.506475ms"
},
"reason": "RATE_LIMIT_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "0.201506475s"
}
],
"message": "You have exhausted your capacity on this model.",
"status": "RESOURCE_EXHAUSTED"
}
}`,
expectedDelay: 201506475 * time.Nanosecond,
expectedModel: "claude-sonnet-4-5",
},
{
name: "429 RESOURCE_EXHAUSTED without RATE_LIMIT_EXCEEDED - should return nil",
body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"metadata": {"model": "claude-sonnet-4-5"},
"reason": "QUOTA_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "3s"
}
]
}
}`,
expectedNil: true,
},
{
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`,
expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high",
},
{
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "SERVICE_UNAVAILABLE"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
]
}
}`,
expectedNil: true,
},
{
name: "wrong status - should return nil",
body: `{
"error": {
"code": 429,
"status": "INVALID_ARGUMENT",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
expectedNil: true,
},
{
name: "missing status - should return nil",
body: `{
"error": {
"code": 429,
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
expectedNil: true,
},
{
name: "milliseconds format is now supported",
body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test-model"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "500ms"}
]
}
}`,
expectedDelay: 500 * time.Millisecond,
expectedModel: "test-model",
},
{
name: "minutes format is supported",
body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "4m50s"}
]
}
}`,
expectedDelay: 4*time.Minute + 50*time.Second,
expectedModel: "gemini-3-pro",
},
{
name: "missing model name - should return nil",
body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
expectedNil: true,
},
{
name: "invalid JSON",
body: `not json`,
expectedNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseAntigravitySmartRetryInfo([]byte(tt.body))
if tt.expectedNil {
if result != nil {
t.Errorf("expected nil, got %+v", result)
}
return
}
if result == nil {
t.Errorf("expected non-nil result")
return
}
if result.RetryDelay != tt.expectedDelay {
t.Errorf("RetryDelay = %v, want %v", result.RetryDelay, tt.expectedDelay)
}
if result.ModelName != tt.expectedModel {
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
}
})
}
}
func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
oauthAccount := &Account{Type: AccountTypeOAuth, Platform: PlatformAntigravity}
setupTokenAccount := &Account{Type: AccountTypeSetupToken, Platform: PlatformAntigravity}
upstreamAccount := &Account{Type: AccountTypeUpstream, Platform: PlatformAntigravity}
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
tests := []struct {
name string
account *Account
body string
expectedShouldRetry bool
expectedShouldRateLimit bool
minWait time.Duration
modelName string
}{
{
name: "OAuth account with short delay (< 7s) - smart retry",
account: oauthAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
expectedShouldRetry: true,
expectedShouldRateLimit: false,
minWait: 1 * time.Second, // 0.5s < 1s, 使用最小等待时间 1s
modelName: "claude-opus-4",
},
{
name: "SetupToken account with short delay - smart retry",
account: setupTokenAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
expectedShouldRetry: true,
expectedShouldRateLimit: false,
minWait: 3 * time.Second,
modelName: "gemini-3-flash",
},
{
name: "OAuth account with long delay (>= 7s) - direct rate limit",
account: oauthAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "claude-sonnet-4-5",
},
{
name: "Upstream account with short delay - smart retry",
account: upstreamAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "2s"}
]
}
}`,
expectedShouldRetry: true,
expectedShouldRateLimit: false,
minWait: 2 * time.Second,
modelName: "claude-sonnet-4-5",
},
{
name: "API Key account - should not trigger",
account: apiKeyAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: false,
},
{
name: "OAuth account with exactly 7s delay - direct rate limit",
account: oauthAccount,
body: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "gemini-pro",
},
{
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
account: oauthAccount,
body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "gemini-3-pro-high",
},
{
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
account: oauthAccount,
body: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-2.5-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}
],
"message": "No capacity available for model gemini-2.5-flash on the server"
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "gemini-2.5-flash",
},
{
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
account: oauthAccount,
body: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
if shouldRetry != tt.expectedShouldRetry {
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
}
if shouldRateLimit != tt.expectedShouldRateLimit {
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
}
if shouldRetry {
if wait < tt.minWait {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
}
}
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
t.Errorf("modelName = %q, want %q", model, tt.modelName)
}
})
}
}
// TestSetModelRateLimitByModelName_UsesOfficialModelID 验证写入端使用官方模型 ID
func TestSetModelRateLimitByModelName_UsesOfficialModelID(t *testing.T) {
tests := []struct {
name string
modelName string
expectedModelKey string
expectedSuccess bool
}{
{
name: "claude-sonnet-4-5 should be stored as-is",
modelName: "claude-sonnet-4-5",
expectedModelKey: "claude-sonnet-4-5",
expectedSuccess: true,
},
{
name: "gemini-3-pro-high should be stored as-is",
modelName: "gemini-3-pro-high",
expectedModelKey: "gemini-3-pro-high",
expectedSuccess: true,
},
{
name: "gemini-3-flash should be stored as-is",
modelName: "gemini-3-flash",
expectedModelKey: "gemini-3-flash",
expectedSuccess: true,
},
{
name: "empty model name should fail",
modelName: "",
expectedModelKey: "",
expectedSuccess: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
resetAt := time.Now().Add(30 * time.Second)
success := setModelRateLimitByModelName(
context.Background(),
repo,
123, // accountID
tt.modelName,
"[test]",
429,
resetAt,
false, // afterSmartRetry
)
require.Equal(t, tt.expectedSuccess, success)
if tt.expectedSuccess {
require.Len(t, repo.modelRateLimitCalls, 1)
call := repo.modelRateLimitCalls[0]
require.Equal(t, int64(123), call.accountID)
// 关键断言:存储的 key 应该是官方模型 ID,而不是 scope
require.Equal(t, tt.expectedModelKey, call.modelKey, "should store official model ID, not scope")
require.WithinDuration(t, resetAt, call.resetAt, time.Second)
} else {
require.Empty(t, repo.modelRateLimitCalls)
}
})
}
}
// TestSetModelRateLimitByModelName_NotConvertToScope 验证不会将模型名转换为 scope
func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
resetAt := time.Now().Add(30 * time.Second)
// 调用 setModelRateLimitByModelName,传入官方模型 ID
success := setModelRateLimitByModelName(
context.Background(),
repo,
456,
"claude-sonnet-4-5", // 官方模型 ID
"[test]",
429,
resetAt,
true, // afterSmartRetry
)
require.True(t, success)
require.Len(t, repo.modelRateLimitCalls, 1)
call := repo.modelRateLimitCalls[0]
// 关键断言:存储的应该是 "claude-sonnet-4-5",而不是 "claude_sonnet"
require.Equal(t, "claude-sonnet-4-5", call.modelKey, "should NOT convert to scope like claude_sonnet")
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
}
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, result)
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check")
}
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 2,
Name: "acc-2",
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(11 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
}
func TestIsAntigravityAccountSwitchError(t *testing.T) {
tests := []struct {
name string
err error
expectedOK bool
expectedID int64
expectedModel string
}{
{
name: "nil error",
err: nil,
expectedOK: false,
},
{
name: "generic error",
err: fmt.Errorf("some error"),
expectedOK: false,
},
{
name: "account switch error",
err: &AntigravityAccountSwitchError{
OriginalAccountID: 123,
RateLimitedModel: "claude-sonnet-4-5",
IsStickySession: true,
},
expectedOK: true,
expectedID: 123,
expectedModel: "claude-sonnet-4-5",
},
{
name: "wrapped account switch error",
err: fmt.Errorf("wrapped: %w", &AntigravityAccountSwitchError{
OriginalAccountID: 456,
RateLimitedModel: "gemini-3-flash",
IsStickySession: false,
}),
expectedOK: true,
expectedID: 456,
expectedModel: "gemini-3-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
switchErr, ok := IsAntigravityAccountSwitchError(tt.err)
require.Equal(t, tt.expectedOK, ok)
if tt.expectedOK {
require.NotNil(t, switchErr)
require.Equal(t, tt.expectedID, switchErr.OriginalAccountID)
require.Equal(t, tt.expectedModel, switchErr.RateLimitedModel)
} else {
require.Nil(t, switchErr)
}
})
}
}
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
err := &AntigravityAccountSwitchError{
OriginalAccountID: 789,
RateLimitedModel: "claude-opus-4-5",
IsStickySession: true,
}
msg := err.Error()
require.Contains(t, msg, "789")
require.Contains(t, msg, "claude-opus-4-5")
}
// stubSchedulerCache 用于测试的 SchedulerCache 实现
type stubSchedulerCache struct {
SchedulerCache
setAccountCalls []*Account
setAccountErr error
}
func (s *stubSchedulerCache) SetAccount(ctx context.Context, account *Account) error {
s.setAccountCalls = append(s.setAccountCalls, account)
return s.setAccountErr
}
// TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache 测试模型限流后更新缓存
func TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache(t *testing.T) {
cache := &stubSchedulerCache{}
snapshotService := &SchedulerSnapshotService{cache: cache}
svc := &AntigravityGatewayService{
schedulerSnapshot: snapshotService,
}
account := &Account{
ID: 100,
Name: "test-account",
Platform: PlatformAntigravity,
}
modelKey := "claude-sonnet-4-5"
resetAt := time.Now().Add(30 * time.Second)
svc.updateAccountModelRateLimitInCache(context.Background(), account, modelKey, resetAt)
// 验证 Extra 字段被正确更新
require.NotNil(t, account.Extra)
limits, ok := account.Extra["model_rate_limits"].(map[string]any)
require.True(t, ok)
modelLimit, ok := limits[modelKey].(map[string]any)
require.True(t, ok)
require.NotEmpty(t, modelLimit["rate_limited_at"])
require.NotEmpty(t, modelLimit["rate_limit_reset_at"])
// 验证 cache.SetAccount 被调用
require.Len(t, cache.setAccountCalls, 1)
require.Equal(t, account.ID, cache.setAccountCalls[0].ID)
}
// TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot 测试 schedulerSnapshot 为 nil 时不 panic
func TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot(t *testing.T) {
svc := &AntigravityGatewayService{
schedulerSnapshot: nil,
}
account := &Account{ID: 1, Name: "test"}
// 不应 panic
svc.updateAccountModelRateLimitInCache(context.Background(), account, "claude-sonnet-4-5", time.Now().Add(30*time.Second))
// Extra 不应被更新(因为函数提前返回)
require.Nil(t, account.Extra)
}
// TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra 测试保留已有的 Extra 数据
func TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra(t *testing.T) {
cache := &stubSchedulerCache{}
snapshotService := &SchedulerSnapshotService{cache: cache}
svc := &AntigravityGatewayService{
schedulerSnapshot: snapshotService,
}
account := &Account{
ID: 200,
Name: "test-account",
Platform: PlatformAntigravity,
Extra: map[string]any{
"existing_key": "existing_value",
"model_rate_limits": map[string]any{
"gemini-3-flash": map[string]any{
"rate_limited_at": "2024-01-01T00:00:00Z",
"rate_limit_reset_at": "2024-01-01T00:05:00Z",
},
},
},
}
svc.updateAccountModelRateLimitInCache(context.Background(), account, "claude-sonnet-4-5", time.Now().Add(30*time.Second))
// 验证已有数据被保留
require.Equal(t, "existing_value", account.Extra["existing_key"])
limits := account.Extra["model_rate_limits"].(map[string]any)
require.NotNil(t, limits["gemini-3-flash"])
require.NotNil(t, limits["claude-sonnet-4-5"])
}
// TestSchedulerSnapshotService_UpdateAccountInCache 测试 UpdateAccountInCache 方法
func TestSchedulerSnapshotService_UpdateAccountInCache(t *testing.T) {
t.Run("calls cache.SetAccount", func(t *testing.T) {
cache := &stubSchedulerCache{}
svc := &SchedulerSnapshotService{cache: cache}
account := &Account{ID: 123, Name: "test"}
err := svc.UpdateAccountInCache(context.Background(), account)
require.NoError(t, err)
require.Len(t, cache.setAccountCalls, 1)
require.Equal(t, int64(123), cache.setAccountCalls[0].ID)
})
t.Run("returns nil when cache is nil", func(t *testing.T) {
svc := &SchedulerSnapshotService{cache: nil}
err := svc.UpdateAccountInCache(context.Background(), &Account{ID: 1})
require.NoError(t, err)
})
t.Run("returns nil when account is nil", func(t *testing.T) {
cache := &stubSchedulerCache{}
svc := &SchedulerSnapshotService{cache: cache}
err := svc.UpdateAccountInCache(context.Background(), nil)
require.NoError(t, err)
require.Empty(t, cache.setAccountCalls)
})
t.Run("propagates cache error", func(t *testing.T) {
expectedErr := fmt.Errorf("cache error")
cache := &stubSchedulerCache{setAccountErr: expectedErr}
svc := &SchedulerSnapshotService{cache: cache}
err := svc.UpdateAccountInCache(context.Background(), &Account{ID: 1})
require.ErrorIs(t, err, expectedErr)
})
}
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
errors []error
callIdx int
calls []string
}
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
idx := m.callIdx
m.calls = append(m.calls, req.URL.String())
m.callIdx++
if idx < len(m.responses) {
return m.responses[idx], m.errors[idx]
}
return nil, nil
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return m.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestHandleSmartRetry_URLLevelRateLimit 测试 URL 级别限流切换
func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
account := &Account{
ID: 1,
Name: "acc-1",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test", "https://ag-2.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionContinueURL, result.action)
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.Nil(t, result.switchError)
}
// TestHandleSmartRetry_LongDelay_ReturnsSwitchError 测试 retryDelay >= 阈值时返回 switchError
func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 1,
Name: "acc-1",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 15s >= 7s 阈值,应该返回 switchError
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for long delay")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_SmartRetrySuccess 测试智能重试成功
func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 1,
Name: "acc-1",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 0.5s < 7s 阈值,应该触发智能重试
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.err)
require.Nil(t, result.switchError, "should not return switchError on success")
require.Len(t, upstream.calls, 1, "should have made one retry call")
}
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
// 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次)
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp1 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp2 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp3 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp1, failResp2, failResp3},
errors: []error{nil, nil, nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 2,
Name: "acc-2",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 3s < 7s 阈值,应该触发智能重试(最多 3 次)
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError after smart retry failed")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-flash", result.switchError.RateLimitedModel)
require.False(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey)
require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing.T) {
account := &Account{
ID: 4,
Name: "acc-4",
Type: AccountTypeAPIKey, // 非 Antigravity 平台账号
Platform: PlatformAnthropic,
}
// 即使是模型限流响应,非 OAuth 账号也应该走默认逻辑
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionContinue, result.action, "non-Antigravity platform account should continue default logic")
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.Nil(t, result.switchError)
}
// TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic 测试非模型限流响应走默认逻辑
func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) {
account := &Account{
ID: 5,
Name: "acc-5",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 但没有 RATE_LIMIT_EXCEEDED 或 MODEL_CAPACITY_EXHAUSTED
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
],
"message": "Quota exceeded"
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionContinue, result.action, "non-model rate limit should continue default logic")
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.Nil(t, result.switchError)
}
// TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError 测试刚好等于阈值时返回 switchError
func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 6,
Name: "acc-6",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 刚好 7s = 7s 阈值,应该返回 switchError
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp)
require.NotNil(t, result.switchError, "exactly at threshold should return switchError")
require.Equal(t, "gemini-pro", result.switchError.RateLimitedModel)
}
// TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates 测试 switchError 正确传播到上层
func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing.T) {
// 模拟 429 + 长延迟的响应
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)
rateLimitResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{rateLimitResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 7,
Name: "acc-7",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result when switchError")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
// 第一次网络错误,第二次成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误)
errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发
}
account := &Account{
ID: 8,
Name: "acc-8",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 0.1s < 7s 阈值,应该触发智能重试
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
require.Len(t, upstream.calls, 2, "should have made two retry calls")
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 9,
Name: "acc-9",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 + RATE_LIMIT_EXCEEDED + 无 retryDelay → 使用默认 1 分钟限流
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.NotNil(t, result.switchError, "should return switchError for no retryDelay")
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
//go:build unit
package service
import (
"testing"
)
func TestApplyThinkingModelSuffix(t *testing.T) {
tests := []struct {
name string
mappedModel string
thinkingEnabled bool
expected string
}{
// Thinking 未开启:保持原样
{
name: "thinking disabled - claude-sonnet-4-5 unchanged",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: "claude-sonnet-4-5",
},
{
name: "thinking disabled - other model unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: false,
expected: "claude-opus-4-6-thinking",
},
// Thinking 开启 + claude-sonnet-4-5:自动添加后缀
{
name: "thinking enabled - claude-sonnet-4-5 becomes thinking version",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
// Thinking 开启 + 其他模型:保持原样
{
name: "thinking enabled - claude-sonnet-4-5-thinking unchanged",
mappedModel: "claude-sonnet-4-5-thinking",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
{
name: "thinking enabled - claude-opus-4-6-thinking unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: true,
expected: "claude-opus-4-6-thinking",
},
{
name: "thinking enabled - gemini model unchanged",
mappedModel: "gemini-3-flash",
thinkingEnabled: true,
expected: "gemini-3-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := applyThinkingModelSuffix(tt.mappedModel, tt.thinkingEnabled)
if result != tt.expected {
t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
tt.mappedModel, tt.thinkingEnabled, result, tt.expected)
}
})
}
}
......@@ -42,7 +42,18 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", errors.New("upstream account missing api_key in credentials")
}
return apiKey, nil
}
if account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestAntigravityTokenProvider_GetAccessToken_Upstream(t *testing.T) {
provider := &AntigravityTokenProvider{}
t.Run("upstream account with valid api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{
"api_key": "sk-test-key-12345",
},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sk-test-key-12345", token)
})
t.Run("upstream account missing api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
t.Run("upstream account with empty api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{
"api_key": "",
},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
t.Run("upstream account with nil credentials", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
}
func TestAntigravityTokenProvider_GetAccessToken_Guards(t *testing.T) {
provider := &AntigravityTokenProvider{}
t.Run("nil account", func(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
})
t.Run("non-antigravity platform", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an antigravity account")
require.Empty(t, token)
})
t.Run("unsupported account type", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an antigravity oauth account")
require.Empty(t, token)
})
}
......@@ -35,6 +35,7 @@ type ConcurrencyCache interface {
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
......@@ -77,6 +78,11 @@ type AccountWithConcurrency struct {
MaxConcurrency int
}
type UserWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
......@@ -84,6 +90,13 @@ type AccountLoadInfo struct {
LoadRate int // 0-100+ (percent)
}
type UserLoadInfo struct {
UserID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
......@@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// GetUsersLoadBatch returns load info for multiple users.
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
if s.cache == nil {
return map[int64]*UserLoadInfo{}, nil
}
return s.cache.GetUsersLoadBatch(ctx, users)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
......
//go:build unit
package service
import (
"context"
"testing"
)
func TestIsForceCacheBilling(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected bool
}{
{
name: "context without force cache billing",
ctx: context.Background(),
expected: false,
},
{
name: "context with force cache billing set to true",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, true),
expected: true,
},
{
name: "context with force cache billing set to false",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, false),
expected: false,
},
{
name: "context with wrong type value",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, "true"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsForceCacheBilling(tt.ctx)
if result != tt.expected {
t.Errorf("IsForceCacheBilling() = %v, want %v", result, tt.expected)
}
})
}
}
func TestWithForceCacheBilling(t *testing.T) {
ctx := context.Background()
// 原始上下文没有标记
if IsForceCacheBilling(ctx) {
t.Error("original context should not have force cache billing")
}
// 使用 WithForceCacheBilling 后应该有标记
newCtx := WithForceCacheBilling(ctx)
if !IsForceCacheBilling(newCtx) {
t.Error("new context should have force cache billing")
}
// 原始上下文应该不受影响
if IsForceCacheBilling(ctx) {
t.Error("original context should still not have force cache billing")
}
}
func TestForceCacheBilling_TokenConversion(t *testing.T) {
tests := []struct {
name string
forceCacheBilling bool
inputTokens int
cacheReadInputTokens int
expectedInputTokens int
expectedCacheReadTokens int
}{
{
name: "force cache billing converts input to cache_read",
forceCacheBilling: true,
inputTokens: 1000,
cacheReadInputTokens: 500,
expectedInputTokens: 0,
expectedCacheReadTokens: 1500, // 500 + 1000
},
{
name: "no force cache billing keeps tokens unchanged",
forceCacheBilling: false,
inputTokens: 1000,
cacheReadInputTokens: 500,
expectedInputTokens: 1000,
expectedCacheReadTokens: 500,
},
{
name: "force cache billing with zero input tokens does nothing",
forceCacheBilling: true,
inputTokens: 0,
cacheReadInputTokens: 500,
expectedInputTokens: 0,
expectedCacheReadTokens: 500,
},
{
name: "force cache billing with zero cache_read tokens",
forceCacheBilling: true,
inputTokens: 1000,
cacheReadInputTokens: 0,
expectedInputTokens: 0,
expectedCacheReadTokens: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 RecordUsage 中的 ForceCacheBilling 逻辑
usage := ClaudeUsage{
InputTokens: tt.inputTokens,
CacheReadInputTokens: tt.cacheReadInputTokens,
}
// 这是 RecordUsage 中的实际逻辑
if tt.forceCacheBilling && usage.InputTokens > 0 {
usage.CacheReadInputTokens += usage.InputTokens
usage.InputTokens = 0
}
if usage.InputTokens != tt.expectedInputTokens {
t.Errorf("InputTokens = %d, want %d", usage.InputTokens, tt.expectedInputTokens)
}
if usage.CacheReadInputTokens != tt.expectedCacheReadTokens {
t.Errorf("CacheReadInputTokens = %d, want %d", usage.CacheReadInputTokens, tt.expectedCacheReadTokens)
}
})
}
}
......@@ -216,6 +216,22 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
......@@ -332,7 +348,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAntigravity)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
......@@ -670,7 +686,7 @@ func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *tes
cfg: testConfig(),
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-sonnet-4-5", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
......@@ -1014,11 +1030,17 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
name: "Antigravity平台-支持默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
model: "claude-sonnet-4-5",
expected: true,
},
{
name: "Antigravity平台-不支持非默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
......@@ -1115,7 +1137,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
......@@ -1123,7 +1145,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) {
groupID := int64(30)
requestedModel := "claude-3-5-sonnet-20241022"
requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
......@@ -1168,7 +1190,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由粘性命中", func(t *testing.T) {
groupID := int64(31)
requestedModel := "claude-3-5-sonnet-20241022"
requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
......@@ -1320,7 +1342,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": resetAt.Format(time.RFC3339),
},
},
......@@ -1465,7 +1487,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
......@@ -1597,7 +1619,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
......@@ -1870,6 +1892,19 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users {
result[user.ID] = &UserLoadInfo{
UserID: user.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
......@@ -2747,7 +2782,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
Concurrency: 5,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": now.Format(time.RFC3339),
},
},
......
......@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ParsedRequest 保存网关请求的预解析结果
......@@ -19,13 +21,14 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
......@@ -69,6 +72,13 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.Messages = messages
}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
return parsed, nil
}
......@@ -466,7 +476,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
if signature != "" && signature != antigravity.DummyThoughtSignature {
newContent = append(newContent, block)
continue
}
......
......@@ -17,6 +17,15 @@ func TestParseGatewayRequest(t *testing.T) {
require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1)
require.False(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
......
......@@ -49,6 +49,29 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
// IsForceCacheBilling 检查是否启用强制缓存计费
func IsForceCacheBilling(ctx context.Context) bool {
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
return v
}
// WithForceCacheBilling 返回带有强制缓存计费标记的上下文
func WithForceCacheBilling(ctx context.Context) context.Context {
return context.WithValue(ctx, ForceCacheBillingContextKey, true)
}
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
......@@ -250,6 +273,13 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
......@@ -265,6 +295,24 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
......@@ -275,16 +323,23 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
......@@ -294,6 +349,10 @@ func shouldClearStickySession(account *Account) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
return true
}
return false
}
......@@ -336,8 +395,9 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
}
func (e *UpstreamFailoverError) Error() string {
......@@ -470,6 +530,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
return accountID, nil
}
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息(uuid, accountID)
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
......@@ -968,6 +1045,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
......@@ -986,12 +1064,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredPlatform++
continue
}
if !account.IsSchedulableForModel(requestedModel) {
filteredModelScope++
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) {
filteredModelMapping++
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
filteredModelMapping++
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
filteredModelScope++
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 窗口费用检查(非粘性会话路径)
......@@ -1006,6 +1085,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
if len(modelScopeSkippedIDs) > 0 {
log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
derefGroupID(groupID), requestedModel, modelScopeSkippedIDs)
}
}
if len(routingCandidates) > 0 {
......@@ -1017,8 +1100,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
......@@ -1075,10 +1158,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad
for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID]
......@@ -1169,14 +1248,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
account.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
......@@ -1234,10 +1313,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
// 窗口费用检查(非粘性会话路径)
......@@ -1265,10 +1344,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
......@@ -1283,47 +1362,108 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
available = newAvailable
}
}
}
......@@ -1740,6 +1880,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
// filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minPriority := accounts[0].account.Priority
for _, acc := range accounts[1:] {
if acc.account.Priority < minPriority {
minPriority = acc.account.Priority
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.account.Priority == minPriority {
result = append(result, acc)
}
}
return result
}
// filterByMinLoadRate 过滤出负载率最低的账号集合
func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minLoadRate := accounts[0].loadInfo.LoadRate
for _, acc := range accounts[1:] {
if acc.loadInfo.LoadRate < minLoadRate {
minLoadRate = acc.loadInfo.LoadRate
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.loadInfo.LoadRate == minLoadRate {
result = append(result, acc)
}
}
return result
}
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 1. 找到最小的 LastUsedAt(nil 被视为最小)
var minTime *time.Time
hasNil := false
for _, acc := range accounts {
if acc.account.LastUsedAt == nil {
hasNil = true
break
}
if minTime == nil || acc.account.LastUsedAt.Before(*minTime) {
minTime = acc.account.LastUsedAt
}
}
// 2. 收集所有具有最小 LastUsedAt 的账号索引
var candidateIdxs []int
for i, acc := range accounts {
if hasNil {
if acc.account.LastUsedAt == nil {
candidateIdxs = append(candidateIdxs, i)
}
} else {
if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) {
candidateIdxs = append(candidateIdxs, i)
}
}
}
// 3. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 5. 随机选择一个
selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))]
return &accounts[selectedIdx]
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
......@@ -1762,6 +2002,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
......@@ -1843,11 +2164,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -1894,10 +2215,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -1946,11 +2267,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -1986,10 +2307,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2056,11 +2377,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -2109,10 +2430,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2161,11 +2482,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -2203,10 +2524,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2250,11 +2571,38 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底
mapped := mapAntigravityModel(account, requestedModel)
if mapped == "" {
return false
}
// 应用 thinking 后缀后检查最终模型是否在账号映射中
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
finalModel := applyThinkingModelSuffix(mapped, enabled)
if finalModel == mapped {
return true // thinking 后缀未改变模型名,映射已通过
}
return account.IsModelSupported(finalModel)
}
return true
}
return s.isModelSupportedByAccount(account, requestedModel)
}
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
return mapAntigravityModel(account, requestedModel) != ""
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
......@@ -2268,13 +2616,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
......@@ -4162,14 +4503,15 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
......@@ -4185,6 +4527,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
......@@ -4345,6 +4696,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService *APIKeyService // API Key 配额服务(可选)
}
......@@ -4356,6 +4708,15 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestGatewayService_isModelSupportedByAccount_AntigravityModelMapping(t *testing.T) {
svc := &GatewayService{}
// 使用 model_mapping 作为白名单(通配符匹配)
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-sonnet-4-5",
"gemini-3-*": "gemini-3-flash",
},
},
}
// claude-* 通配符匹配
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-opus-4-6"))
// gemini-3-* 通配符匹配
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-pro-high"))
// gemini-2.5-* 不匹配(不在 model_mapping 中)
require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-flash"))
require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro"))
// 其他平台模型不支持
require.False(t, svc.isModelSupportedByAccount(account, "gpt-4"))
// 空模型允许
require.True(t, svc.isModelSupportedByAccount(account, ""))
}
func TestGatewayService_isModelSupportedByAccount_AntigravityNoMapping(t *testing.T) {
svc := &GatewayService{}
// 未配置 model_mapping 时,使用默认映射(domain.DefaultAntigravityModelMapping)
// 只有默认映射中的模型才被支持
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{},
}
// 默认映射中的模型应该被支持
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5"))
// 不在默认映射中的模型不被支持
require.False(t, svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022"))
require.False(t, svc.isModelSupportedByAccount(account, "claude-unknown-model"))
// 非 claude-/gemini- 前缀仍然不支持
require.False(t, svc.isModelSupportedByAccount(account, "gpt-4"))
}
// TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode 测试 thinking 模式下的模型支持检查
// 验证调度时使用映射后的最终模型名(包括 thinking 后缀)来检查 model_mapping 支持
func TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
modelMapping map[string]any
requestedModel string
thinkingEnabled bool
expected bool
}{
// 场景 1: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=true
// mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false
{
name: "thinking_enabled_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: false,
},
// 场景 2: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=false
// mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false
{
name: "thinking_disabled_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: false,
},
// 场景 3: 配置 claude-sonnet-4-5(非 thinking),请求 claude-sonnet-4-5 + thinking=true
// 最终模型名 = claude-sonnet-4-5-thinking,不在 mapping 中,应该不匹配
{
name: "thinking_enabled_no_match_non_thinking_mapping",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: false,
},
// 场景 4: 配置两种模型,请求 claude-sonnet-4-5 + thinking=true,应该匹配 thinking 版本
{
name: "both_models_thinking_enabled_matches_thinking",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: true,
},
// 场景 5: 配置两种模型,请求 claude-sonnet-4-5 + thinking=false,应该匹配非 thinking 版本
{
name: "both_models_thinking_disabled_matches_non_thinking",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: true,
},
// 场景 6: 通配符 claude-* 应该同时匹配 thinking 和非 thinking
{
name: "wildcard_matches_thinking",
modelMapping: map[string]any{
"claude-*": "claude-sonnet-4-5",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: true, // claude-sonnet-4-5-thinking 匹配 claude-*
},
// 场景 7: 只配置 thinking 变体但没有基础模型映射 → 返回 false
// mapAntigravityModel 找不到 claude-opus-4-6 的映射
{
name: "opus_thinking_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
},
requestedModel: "claude-opus-4-6",
thinkingEnabled: true,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": tt.modelMapping,
},
}
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, tt.thinkingEnabled)
result := svc.isModelSupportedByAccountWithContext(ctx, account, tt.requestedModel)
require.Equal(t, tt.expected, result,
"isModelSupportedByAccountWithContext(ctx[thinking=%v], account, %q) = %v, want %v",
tt.thinkingEnabled, tt.requestedModel, result, tt.expected)
})
}
}
// TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault 测试自定义模型映射中
// 不在 DefaultAntigravityModelMapping 中的模型能通过调度
func TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault(t *testing.T) {
svc := &GatewayService{}
// 自定义映射中包含不在默认映射中的模型
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "actual-upstream-model",
"gpt-4o": "some-upstream-model",
"llama-3-70b": "llama-3-70b-upstream",
"claude-sonnet-4-5": "claude-sonnet-4-5",
},
},
}
// 自定义模型应该通过(不在 DefaultAntigravityModelMapping 中也可以)
require.True(t, svc.isModelSupportedByAccount(account, "my-custom-model"))
require.True(t, svc.isModelSupportedByAccount(account, "gpt-4o"))
require.True(t, svc.isModelSupportedByAccount(account, "llama-3-70b"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
// 不在自定义映射中的模型不通过
require.False(t, svc.isModelSupportedByAccount(account, "gpt-3.5-turbo"))
require.False(t, svc.isModelSupportedByAccount(account, "unknown-model"))
// 空模型允许
require.True(t, svc.isModelSupportedByAccount(account, ""))
}
// TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking
// 测试自定义映射 + thinking 模式的交互
func TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking(t *testing.T) {
svc := &GatewayService{}
// 自定义映射同时配置基础模型和 thinking 变体
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"my-custom-model": "upstream-model",
},
},
}
// thinking=true: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → +thinking → check IsModelSupported(claude-sonnet-4-5-thinking)=true
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5"))
// thinking=false: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → check IsModelSupported(claude-sonnet-4-5)=true
ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5"))
// 自定义模型(非 claude)不受 thinking 后缀影响,mapped 成功即通过
ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "my-custom-model"))
}
......@@ -200,7 +200,7 @@ func (s *GeminiMessagesCompatService) tryStickySessionHit(
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
if shouldClearStickySession(account, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
......@@ -230,7 +230,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
......@@ -362,7 +362,10 @@ func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
return mapAntigravityModel(account, requestedModel) != ""
}
return account.IsModelSupported(requestedModel)
}
......@@ -2658,7 +2661,9 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil {
ts := time.Now().Unix() + int64(dur.Seconds())
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
// which can affect scheduling decisions around thresholds (like 10s).
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts
}
}
......
......@@ -265,6 +265,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
......@@ -880,7 +896,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
model: "claude-sonnet-4-5",
expected: true,
},
{
......@@ -889,6 +905,39 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
model: "gpt-4",
expected: false,
},
{
name: "Antigravity平台-空模型允许",
account: &Account{Platform: PlatformAntigravity},
model: "",
expected: true,
},
{
name: "Antigravity平台-自定义映射-支持自定义模型",
account: &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "upstream-model",
"gpt-4o": "some-model",
},
},
},
model: "my-custom-model",
expected: true,
},
{
name: "Antigravity平台-自定义映射-不在映射中的模型不支持",
account: &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "upstream-model",
},
},
},
model: "claude-sonnet-4-5",
expected: false,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini},
......
package service
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
h := xxhash.Sum64(data)
return strconv.FormatUint(h, 36)
}
// BuildGeminiDigestChain 根据 Gemini 请求生成摘要链
// 格式: s:<hash>-u:<hash>-m:<hash>-u:<hash>-...
// s = systemInstruction, u = user, m = model
func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
if req == nil {
return ""
}
var parts []string
// 1. system instruction
if req.SystemInstruction != nil && len(req.SystemInstruction.Parts) > 0 {
partsData, _ := json.Marshal(req.SystemInstruction.Parts)
parts = append(parts, "s:"+shortHash(partsData))
}
// 2. contents
for _, c := range req.Contents {
prefix := "u" // user
if c.Role == "model" {
prefix = "m"
}
partsData, _ := json.Marshal(c.Parts)
parts = append(parts, prefix+":"+shortHash(partsData))
}
return strings.Join(parts, "-")
}
// GenerateGeminiPrefixHash 生成前缀 hash(用于分区隔离)
// 组合: userID + apiKeyID + ip + userAgent + platform + model
// 返回 16 字符的 Base64 编码的 SHA256 前缀
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
// 组合所有标识符
combined := strconv.FormatInt(userID, 10) + ":" +
strconv.FormatInt(apiKeyID, 10) + ":" +
ip + ":" +
userAgent + ":" +
platform + ":" +
model
hash := sha256.Sum256([]byte(combined))
// 取前 12 字节,Base64 编码后正好 16 字符
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
if value == "" {
return "", 0, false
}
// 找到最后一个 ":" 的位置(因为 uuid 可能包含 ":")
i := strings.LastIndex(value, ":")
if i <= 0 || i >= len(value)-1 {
return "", 0, false
}
uuid = value[:i]
accountID, err := strconv.ParseInt(value[i+1:], 10, 64)
if err != nil {
return "", 0, false
}
return uuid, accountID, true
}
// FormatGeminiSessionValue 格式化 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func FormatGeminiSessionValue(uuid string, accountID int64) string {
return uuid + ":" + strconv.FormatInt(accountID, 10)
}
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
func GenerateGeminiDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return geminiDigestSessionKeyPrefix + prefix + ":" + uuidPart
}
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
accountID := int64(100)
// 模拟第一轮对话
req1 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}},
},
}
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
if foundUUID != sessionUUID {
t.Errorf("Round 2: expected UUID %s, got %s", sessionUUID, foundUUID)
}
if foundAccID != accountID {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I can help with coding, writing, and more!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Great, help me write some Go code"}}},
},
}
chain3 := BuildGeminiDigestChain(req3)
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
if foundUUID != sessionUUID {
t.Errorf("Round 3: expected UUID %s, got %s", sessionUUID, foundUUID)
}
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 第一个会话
req1 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Tell me about Go programming"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What's the weather today?"}}},
},
}
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
// 查找应该返回最长匹配(账号 3)
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
func TestShortHash(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"empty", []byte{}},
{"simple", []byte("hello world")},
{"json", []byte(`{"role":"user","parts":[{"text":"hello"}]}`)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shortHash(tt.input)
// Base36 编码的 uint64 最长 13 个字符
if len(result) > 13 {
t.Errorf("shortHash result too long: %d characters", len(result))
}
// 相同输入应该产生相同输出
result2 := shortHash(tt.input)
if result != result2 {
t.Errorf("shortHash not deterministic: %s vs %s", result, result2)
}
})
}
}
func TestBuildGeminiDigestChain(t *testing.T) {
tests := []struct {
name string
req *antigravity.GeminiRequest
wantLen int // 预期的分段数量
hasEmpty bool // 是否应该是空字符串
}{
{
name: "nil request",
req: nil,
hasEmpty: true,
},
{
name: "empty contents",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{},
},
hasEmpty: true,
},
{
name: "single user message",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
},
wantLen: 1, // u:<hash>
},
{
name: "user and model messages",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi there"}}},
},
},
wantLen: 2, // u:<hash>-m:<hash>
},
{
name: "with system instruction",
req: &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Role: "user",
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
},
wantLen: 2, // s:<hash>-u:<hash>
},
{
name: "conversation with system",
req: &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Role: "user",
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "how are you?"}}},
},
},
wantLen: 4, // s:<hash>-u:<hash>-m:<hash>-u:<hash>
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildGeminiDigestChain(tt.req)
if tt.hasEmpty {
if result != "" {
t.Errorf("expected empty string, got: %s", result)
}
return
}
// 检查分段数量
parts := splitChain(result)
if len(parts) != tt.wantLen {
t.Errorf("expected %d parts, got %d: %s", tt.wantLen, len(parts), result)
}
// 验证每个分段的格式
for _, part := range parts {
if len(part) < 3 || part[1] != ':' {
t.Errorf("invalid part format: %s", part)
}
prefix := part[0]
if prefix != 's' && prefix != 'u' && prefix != 'm' {
t.Errorf("invalid prefix: %c", prefix)
}
}
})
}
}
func TestGenerateGeminiPrefixHash(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
hash3 := GenerateGeminiPrefixHash(2, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
// 相同输入应该产生相同输出
if hash1 != hash2 {
t.Errorf("GenerateGeminiPrefixHash not deterministic: %s vs %s", hash1, hash2)
}
// 不同输入应该产生不同输出
if hash1 == hash3 {
t.Errorf("GenerateGeminiPrefixHash collision for different inputs")
}
// Base64 URL 编码的 12 字节正好是 16 字符
if len(hash1) != 16 {
t.Errorf("expected 16 characters, got %d: %s", len(hash1), hash1)
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
value string
wantUUID string
wantAccID int64
wantOK bool
}{
{
name: "empty",
value: "",
wantOK: false,
},
{
name: "no colon",
value: "abc123",
wantOK: false,
},
{
name: "valid",
value: "uuid-1234:100",
wantUUID: "uuid-1234",
wantAccID: 100,
wantOK: true,
},
{
name: "uuid with colon",
value: "a:b:c:123",
wantUUID: "a:b:c",
wantAccID: 123,
wantOK: true,
},
{
name: "invalid account id",
value: "uuid:abc",
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uuid, accID, ok := ParseGeminiSessionValue(tt.value)
if ok != tt.wantOK {
t.Errorf("ok: expected %v, got %v", tt.wantOK, ok)
}
if tt.wantOK {
if uuid != tt.wantUUID {
t.Errorf("uuid: expected %s, got %s", tt.wantUUID, uuid)
}
if accID != tt.wantAccID {
t.Errorf("accountID: expected %d, got %d", tt.wantAccID, accID)
}
}
})
}
}
func TestFormatGeminiSessionValue(t *testing.T) {
result := FormatGeminiSessionValue("test-uuid", 123)
expected := "test-uuid:123"
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
// 验证往返一致性
uuid, accID, ok := ParseGeminiSessionValue(result)
if !ok {
t.Error("ParseGeminiSessionValue failed on formatted value")
}
if uuid != "test-uuid" || accID != 123 {
t.Errorf("round-trip failed: uuid=%s, accID=%d", uuid, accID)
}
}
// splitChain 辅助函数:按 "-" 分割摘要链
func splitChain(chain string) []string {
if chain == "" {
return nil
}
var parts []string
start := 0
for i := 0; i < len(chain); i++ {
if chain[i] == '-' {
parts = append(parts, chain[start:i])
start = i + 1
}
}
if start < len(chain) {
parts = append(parts, chain[start:])
}
return parts
}
func TestDigestChainDifferentSysInstruction(t *testing.T) {
req1 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_ORIGINAL"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
req2 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_MODIFIED"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
if chain1 == chain2 {
t.Error("Different systemInstruction should produce different chains")
}
}
func TestDigestChainTamperedMiddleContent(t *testing.T) {
req1 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "ORIGINAL_REPLY"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
},
}
req2 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "TAMPERED_REPLY"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
if chain1 == chain2 {
t.Error("Tampered middle content should produce different chains")
}
// 验证第一个 user 的 hash 相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
if parts1[1] == parts2[1] {
t.Error("Model reply hash should be different")
}
}
func TestGenerateGeminiDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "gemini:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars prefix and uuid",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "gemini:digest:12345678:abcdefgh",
},
{
name: "short hash and short uuid (less than 8)",
prefixHash: "abc",
uuid: "xyz",
want: "gemini:digest:abc:xyz",
},
{
name: "empty hash and uuid",
prefixHash: "",
uuid: "",
want: "gemini:digest::",
},
{
name: "normal prefix with short uuid",
prefixHash: "abcdefgh12345678",
uuid: "short",
want: "gemini:digest:abcdefgh:short",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateGeminiDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateGeminiDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证确定性:相同输入产生相同输出
t.Run("deterministic", func(t *testing.T) {
hash := "testprefix123456"
uuid := "test-uuid-12345"
result1 := GenerateGeminiDigestSessionKey(hash, uuid)
result2 := GenerateGeminiDigestSessionKey(hash, uuid)
if result1 != result2 {
t.Errorf("GenerateGeminiDigestSessionKey not deterministic: %s vs %s", result1, result2)
}
})
// 验证不同 uuid 产生不同 sessionKey(负载均衡核心逻辑)
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
uuid1 := "uuid0001-session-a"
uuid2 := "uuid0002-session-b"
result1 := GenerateGeminiDigestSessionKey(hash, uuid1)
result2 := GenerateGeminiDigestSessionKey(hash, uuid2)
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
package service
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const modelRateLimitsKey = "model_rate_limits"
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
model := strings.ToLower(strings.TrimSpace(requestedModel))
if model == "" {
return "", false
// isRateLimitActiveForKey 检查指定 key 的限流是否生效
func (a *Account) isRateLimitActiveForKey(key string) bool {
resetAt := a.modelRateLimitResetAt(key)
return resetAt != nil && time.Now().Before(*resetAt)
}
// getRateLimitRemainingForKey 获取指定 key 的限流剩余时间,0 表示未限流或已过期
func (a *Account) getRateLimitRemainingForKey(key string) time.Duration {
resetAt := a.modelRateLimitResetAt(key)
if resetAt == nil {
return 0
}
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "sonnet") {
return modelRateLimitScopeClaudeSonnet, true
remaining := time.Until(*resetAt)
if remaining > 0 {
return remaining
}
return "", false
return 0
}
func (a *Account) isModelRateLimited(requestedModel string) bool {
scope, ok := resolveModelRateLimitScope(requestedModel)
if !ok {
func (a *Account) isModelRateLimitedWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
resetAt := a.modelRateLimitResetAt(scope)
if resetAt == nil {
modelKey := a.GetMappedModel(requestedModel)
if a.Platform == PlatformAntigravity {
modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
}
modelKey = strings.TrimSpace(modelKey)
if modelKey == "" {
return false
}
return time.Now().Before(*resetAt)
return a.isRateLimitActiveForKey(modelKey)
}
// GetModelRateLimitRemainingTime 获取模型限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetModelRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetModelRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
func (a *Account) GetModelRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelKey := a.GetMappedModel(requestedModel)
if a.Platform == PlatformAntigravity {
modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
}
modelKey = strings.TrimSpace(modelKey)
if modelKey == "" {
return 0
}
return a.getRateLimitRemainingForKey(modelKey)
}
func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requestedModel string) string {
modelKey := mapAntigravityModel(account, requestedModel)
if modelKey == "" {
return ""
}
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking)
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
modelKey = applyThinkingModelSuffix(modelKey, enabled)
}
return modelKey
}
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment