Commit 5e98445b authored by erio
Browse files

feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops

Key changes:
- Upgrade model mapping: Opus 4.5 → Opus 4.6-thinking with precise matching
- Unified rate limiting: scope-level → model-level with Redis snapshot sync
- Load-balanced scheduling by call count with smart retry mechanism
- Force cache billing support
- Model identity injection in prompts with leak prevention
- Thinking mode auto-handling (max_tokens/budget_tokens fix)
- Frontend: whitelist mode toggle, model mapping validation, status indicators
- Gemini session fallback with Redis Trie O(L) matching
- Ops: enhanced concurrency monitoring, account availability, retry logic
- Migration scripts: 049-051 for model mapping unification
parent e617b45b
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -113,7 +114,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { ...@@ -113,7 +114,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
c, _ := gin.CreateTestContext(writer) c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{ body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5", "model": "claude-opus-4-6",
"messages": []map[string]any{ "messages": []map[string]any{
{"role": "user", "content": "hi"}, {"role": "user", "content": "hi"},
}, },
...@@ -149,7 +150,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { ...@@ -149,7 +150,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
}, },
} }
result, err := svc.Forward(context.Background(), c, account, body) result, err := svc.Forward(context.Background(), c, account, body, false)
require.Nil(t, result) require.Nil(t, result)
var promptErr *PromptTooLongError var promptErr *PromptTooLongError
...@@ -166,27 +167,227 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { ...@@ -166,27 +167,227 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
require.Equal(t, "prompt_too_long", events[0].Kind) require.Equal(t, "prompt_too_long", events[0].Kind)
} }
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) { // TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover
t.Setenv(antigravityMaxRetriesEnv, "4") // 验证:当账号存在模型限流且剩余时间 >= antigravityRateLimitThreshold 时,
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7") // Forward 方法应返回 UpstreamFailoverError,触发 Handler 切换账号
t.Setenv(antigravityMaxRetriesClaudeEnv, "") func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) {
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") gin.SetMode(gin.TestMode)
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 1,
Name: "acc-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false) result, err := svc.Forward(context.Background(), c, account, body, false)
require.Equal(t, 4, got) require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true) // 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
require.Equal(t, 7, got) var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求,ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
} }
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) { // TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover
t.Setenv(antigravityMaxRetriesEnv, "5") // 验证:ForwardGemini 方法同样能正确将 AntigravityAccountSwitchError 转换为 UpstreamFailoverError
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "") func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) {
t.Setenv(antigravityMaxRetriesClaudeEnv, "") gin.SetMode(gin.TestMode)
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") writer := httptest.NewRecorder()
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 2,
Name: "acc-gemini-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
// 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求,ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling verifies
// that when Forward is invoked with isStickySession = true and the account
// carries a model rate-limit entry, the resulting UpstreamFailoverError has
// ForceCacheBilling set to true.
func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	payload, err := json.Marshal(map[string]any{
		"model":    "claude-opus-4-6",
		"messages": []map[string]string{{"role": "user", "content": "hello"}},
	})
	require.NoError(t, err)

	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(payload))

	gateway := &AntigravityGatewayService{
		tokenProvider: &AntigravityTokenProvider{},
		httpUpstream:  &httpUpstreamStub{resp: nil, err: nil},
	}

	// Model rate limit with 30 seconds remaining (noted by the author as being
	// above antigravityRateLimitThreshold, 7s).
	resetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
	limitedModels := map[string]any{
		"claude-opus-4-6-thinking": map[string]any{
			"rate_limit_reset_at": resetAt,
		},
	}
	acc := &Account{
		ID:          3,
		Name:        "acc-sticky-rate-limited",
		Platform:    PlatformAntigravity,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Concurrency: 1,
		Credentials: map[string]any{"access_token": "token"},
		Extra:       map[string]any{modelRateLimitsKey: limitedModels},
	}

	// Pass isStickySession = true.
	res, fwdErr := gateway.Forward(context.Background(), ginCtx, acc, payload, true)
	require.Nil(t, res, "Forward should not return result when model rate limited")
	require.NotNil(t, fwdErr, "Forward should return error")

	// Core check: a sticky-session switch must set ForceCacheBilling to true.
	var switchErr *UpstreamFailoverError
	require.ErrorAs(t, fwdErr, &switchErr, "error should be UpstreamFailoverError to trigger account switch")
	require.Equal(t, http.StatusServiceUnavailable, switchErr.StatusCode)
	require.True(t, switchErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 4,
Name: "acc-gemini-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true) // 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
require.Equal(t, 5, got) var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
} }
...@@ -14,32 +14,28 @@ func TestIsAntigravityModelSupported(t *testing.T) { ...@@ -14,32 +14,28 @@ func TestIsAntigravityModelSupported(t *testing.T) {
model string model string
expected bool expected bool
}{ }{
// 直接支持的模型 // 在默认映射中的模型(支持)
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, {"默认映射 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true}, {"默认映射 - claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true}, {"默认映射 - claude-opus-4-6", "claude-opus-4-6", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true}, {"默认映射 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true}, {"默认映射 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true}, {"默认映射 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"默认映射 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
// 可映射的模型 {"默认映射 - gemini-3-pro-high", "gemini-3-pro-high", true},
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true}, {"默认映射 - claude-haiku-4-5", "claude-haiku-4-5", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// Claude 前缀兜底 // 不在默认映射中的模型(不支持)
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true}, {"未配置 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", false},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true}, {"未配置 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", false},
{"Claude前缀 - claude-future-version", "claude-future-version", true}, {"未配置 - claude-3-haiku-20240307", "claude-3-haiku-20240307", false},
{"未配置 - gemini-unknown-model", "gemini-unknown-model", false},
{"未配置 - gemini-future-version", "gemini-future-version", false},
{"未配置 - claude-unknown-model", "claude-unknown-model", false},
{"未配置 - claude-3-opus-20240229", "claude-3-opus-20240229", false},
{"未配置 - claude-future-version", "claude-future-version", false},
// 不支持的模型 // 非 Claude/Gemini 模型(不支持)
{"不支持 - gpt-4", "gpt-4", false}, {"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false}, {"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false}, {"不支持 - llama-3", "llama-3", false},
...@@ -64,7 +60,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -64,7 +60,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
accountMapping map[string]string accountMapping map[string]string
expected string expected string
}{ }{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any) // 1. 账户级映射优先
{ {
name: "账户映射优先", name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022", requestedModel: "claude-3-5-sonnet-20241022",
...@@ -72,120 +68,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -72,120 +68,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "custom-model", expected: "custom-model",
}, },
{ {
name: "账户映射覆盖系统映射", name: "账户映射 - 可覆盖默认映射的模型",
requestedModel: "claude-sonnet-4-5",
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
expected: "my-custom-sonnet",
},
{
name: "账户映射 - 可覆盖未知模型",
requestedModel: "claude-opus-4", requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"}, accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus", expected: "my-opus",
}, },
// 2. 系统默认映射 // 2. 默认映射(DefaultAntigravityModelMapping)
{ {
name: "系统映射 - claude-3-5-sonnet-20241022", name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
requestedModel: "claude-3-5-sonnet-20241022", requestedModel: "claude-opus-4-6",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "系统映射 - claude-3-5-sonnet-20240620", name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
requestedModel: "claude-3-5-sonnet-20240620", requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "系统映射 - claude-opus-4", name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4", requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil, accountMapping: nil,
expected: "claude-opus-4-5-thinking", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "系统映射 - claude-opus-4-5-20251101", name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-opus-4-5-20251101", requestedModel: "claude-haiku-4-5",
accountMapping: nil, accountMapping: nil,
expected: "claude-opus-4-5-thinking", expected: "claude-sonnet-4-5",
}, },
{ {
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5", name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4", requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
{ {
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5", name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5", requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
// 3. 默认映射中的透传(映射到自己)
{ {
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5", name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307", requestedModel: "claude-sonnet-4-5",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
{ {
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-haiku-4-5-20251001", requestedModel: "claude-opus-4-6-thinking",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "系统映射 - claude-sonnet-4-5-20250929", name: "默认映射透传 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-20250929", requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5-thinking",
}, },
// 3. Gemini 2.5 → 3 映射
{ {
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash", name: "默认映射透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash", requestedModel: "gemini-2.5-flash",
accountMapping: nil, accountMapping: nil,
expected: "gemini-3-flash", expected: "gemini-2.5-flash",
}, },
{ {
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high", name: "默认映射透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro", requestedModel: "gemini-2.5-pro",
accountMapping: nil, accountMapping: nil,
expected: "gemini-3-pro-high", expected: "gemini-2.5-pro",
}, },
{ {
name: "Gemini透传 - gemini-future-model", name: "默认映射透传 - gemini-3-flash",
requestedModel: "gemini-future-model", requestedModel: "gemini-3-flash",
accountMapping: nil, accountMapping: nil,
expected: "gemini-future-model", expected: "gemini-3-flash",
}, },
// 4. 直接支持的模型 // 4. 未在默认映射中的模型返回空字符串(不支持)
{ {
name: "直接支持 - claude-sonnet-4-5", name: "未知模型 - claude-unknown 返回空",
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-unknown",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "",
}, },
{ {
name: "直接支持 - claude-opus-4-5-thinking", name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
requestedModel: "claude-opus-4-5-thinking", requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil, accountMapping: nil,
expected: "claude-opus-4-5-thinking", expected: "",
}, },
{ {
name: "直接支持 - claude-sonnet-4-5-thinking", name: "未知模型 - claude-3-opus-20240229 返回空",
requestedModel: "claude-sonnet-4-5-thinking", requestedModel: "claude-3-opus-20240229",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5-thinking", expected: "",
}, },
// 5. 默认值 fallback(未知 claude 模型)
{ {
name: "默认值 - claude-unknown", name: "未知模型 - claude-opus-4 返回空",
requestedModel: "claude-unknown", requestedModel: "claude-opus-4",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "",
}, },
{ {
name: "默认值 - claude-3-opus-20240229", name: "未知模型 - gemini-future-model 返回空",
requestedModel: "claude-3-opus-20240229", requestedModel: "gemini-future-model",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "",
}, },
} }
...@@ -219,12 +219,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) { ...@@ -219,12 +219,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
requestedModel string requestedModel string
expected string expected string
}{ }{
// 空字符串回退到默认值 // 空字符串和非 claude/gemini 前缀返回空字符串
{"空字符串", "", "claude-sonnet-4-5"}, {"空字符串", "", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
// 非 claude/gemini 前缀回退到默认值 {"非claude/gemini前缀 - llama", "llama-3", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
} }
for _, tt := range tests { for _, tt := range tests {
...@@ -248,10 +246,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) { ...@@ -248,10 +246,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, {"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true}, {"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射 // 可映射(有明确前缀映射)
{"可映射 - claude-opus-4", "claude-opus-4", true}, {"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传 // 前缀透传(claude 和 gemini 前缀)
{"Gemini前缀", "gemini-unknown", true}, {"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true}, {"Claude前缀", "claude-unknown", true},
......
package service package service
import ( import (
"context"
"slices" "slices"
"strings" "strings"
"time" "time"
...@@ -57,15 +58,20 @@ func normalizeAntigravityModelName(model string) string { ...@@ -57,15 +58,20 @@ func normalizeAntigravityModelName(model string) string {
return normalized return normalized
} }
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度 // IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool { func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
}
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
if a == nil { if a == nil {
return false return false
} }
if !a.IsSchedulable() { if !a.IsSchedulable() {
return false return false
} }
if a.isModelRateLimited(requestedModel) { if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false return false
} }
if a.Platform != PlatformAntigravity { if a.Platform != PlatformAntigravity {
...@@ -132,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 { ...@@ -132,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
} }
return result return result
} }
// GetQuotaScopeRateLimitRemainingTime reports how long the quota-scope rate
// limit that applies to requestedModel still has to run.
//
// It returns 0 when the receiver is nil, the account is not an Antigravity
// account, no quota scope can be resolved for the model, no reset time is
// recorded for that scope, or the reset time has already passed.
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
	if a == nil {
		return 0
	}
	if a.Platform != PlatformAntigravity {
		return 0
	}

	quotaScope, resolved := resolveAntigravityQuotaScope(requestedModel)
	if !resolved {
		return 0
	}

	deadline := a.antigravityQuotaScopeResetAt(quotaScope)
	if deadline == nil {
		return 0
	}

	left := time.Until(*deadline)
	if left <= 0 {
		// Already expired: treat as not limited.
		return 0
	}
	return left
}
// GetRateLimitRemainingTime returns the remaining rate-limit time for
// requestedModel, taking the larger of the model-level and quota-scope-level
// values. A result of 0 means not limited or already expired.
//
// It delegates to GetRateLimitRemainingTimeWithContext using
// context.Background().
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
	ctx := context.Background()
	return a.GetRateLimitRemainingTimeWithContext(ctx, requestedModel)
}
// GetRateLimitRemainingTimeWithContext returns the remaining rate-limit time
// for requestedModel: the maximum of the model-level remaining time and the
// quota-scope-level remaining time. A result of 0 means not limited or already
// expired; a nil receiver also yields 0.
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
	if a == nil {
		return 0
	}

	modelLeft := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
	scopeLeft := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)

	// Whichever limit lasts longer decides when the account is usable again.
	if scopeLeft >= modelLeft {
		return scopeLeft
	}
	return modelLeft
}
...@@ -21,6 +21,23 @@ type stubAntigravityUpstream struct { ...@@ -21,6 +21,23 @@ type stubAntigravityUpstream struct {
calls []string calls []string
} }
// recordingOKUpstream is a test double that counts how many times it has been
// invoked and always answers with a 200 response whose body is "ok".
type recordingOKUpstream struct {
	calls int // number of requests served via Do (including those routed through DoWithTLS)
}

// Do bumps the call counter and returns a fresh 200 OK response with an empty
// header set and the body "ok". All arguments are ignored.
func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	r.calls += 1
	okResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("ok")),
	}
	return okResp, nil
}

// DoWithTLS forwards to Do; enableTLSFingerprint is ignored.
func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return r.Do(req, proxyURL, accountID, accountConcurrency)
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String() url := req.URL.String()
s.calls = append(s.calls, url) s.calls = append(s.calls, url)
...@@ -53,10 +70,17 @@ type rateLimitCall struct { ...@@ -53,10 +70,17 @@ type rateLimitCall struct {
resetAt time.Time resetAt time.Time
} }
// modelRateLimitCall captures the arguments of one SetModelRateLimit
// invocation on the stub account repository so tests can assert on them.
type modelRateLimitCall struct {
accountID int64
modelKey string // stored key (expected to be the official model ID, e.g. "claude-sonnet-4-5")
resetAt time.Time
}
type stubAntigravityAccountRepo struct { type stubAntigravityAccountRepo struct {
AccountRepository AccountRepository
scopeCalls []scopeLimitCall scopeCalls []scopeLimitCall
rateCalls []rateLimitCall rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
} }
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
...@@ -69,6 +93,11 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6 ...@@ -69,6 +93,11 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
return nil return nil
} }
// SetModelRateLimit records the call (account ID, model key, reset time) in
// modelRateLimitCalls and always succeeds. ctx is unused.
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error {
	recorded := modelRateLimitCall{
		accountID: id,
		modelKey:  modelKey,
		resetAt:   resetAt,
	}
	s.modelRateLimitCalls = append(s.modelRateLimitCalls, recorded)
	return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability oldAvailability := antigravity.DefaultURLAvailability
...@@ -94,17 +123,19 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { ...@@ -94,17 +123,19 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var handleErrorCalled bool var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]", prefix: "[test]",
ctx: context.Background(), ctx: context.Background(),
account: account, account: account,
proxyURL: "", proxyURL: "",
accessToken: "token", accessToken: "token",
action: "generateContent", action: "generateContent",
body: []byte(`{"input":"test"}`), body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude, quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream, httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true handleErrorCalled = true
return nil
}, },
}) })
...@@ -123,14 +154,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { ...@@ -123,14 +154,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0]) require.Equal(t, base2, available[0])
} }
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) { func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true") // 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo} svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s") body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1) require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls) require.Empty(t, repo.rateCalls)
...@@ -140,20 +171,122 @@ func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) ...@@ -140,20 +171,122 @@ func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
} }
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) { // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
t.Setenv(antigravityScopeRateLimitEnv, "false") func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 1, Name: "acc-1", Platform: PlatformAntigravity}
// 429 + RATE_LIMIT_EXCEEDED + 模型名 → 模型限流
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
require.True(t, result.Handled)
require.NotNil(t, result.SwitchError)
require.Equal(t, "claude-sonnet-4-5", result.SwitchError.RateLimitedModel)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit covers a 429 response that is
// not a model-level rate limit: handleUpstreamError is expected to return nil,
// record no model rate limit, and record exactly one scope-level limit.
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}

	// 429 with an ordinary rate-limit body (no RATE_LIMIT_EXCEEDED reason) -> scope limit.
	respBody := buildGeminiRateLimitBody("5s")

	outcome := gateway.handleUpstreamError(
		context.Background(), "[test]", acc,
		http.StatusTooManyRequests, http.Header{}, respBody,
		AntigravityQuotaScopeClaude, 0, "", false,
	)

	// No model rate limit should be triggered; the scope limit is used instead.
	require.Nil(t, outcome)
	require.Empty(t, accountRepo.modelRateLimitCalls)
	require.Len(t, accountRepo.scopeCalls, 1)
	require.Equal(t, AntigravityQuotaScopeClaude, accountRepo.scopeCalls[0].scope)
}
// TestHandleUpstreamError_503_ModelRateLimit covers a 503 response whose error
// details carry MODEL_CAPACITY_EXHAUSTED: a model-level rate limit is expected
// to be recorded for the model named in the response metadata.
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}

	// 503 + MODEL_CAPACITY_EXHAUSTED -> model rate limit.
	respBody := []byte(`{
"error": {
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)

	outcome := gateway.handleUpstreamError(
		context.Background(), "[test]", acc,
		http.StatusServiceUnavailable, http.Header{}, respBody,
		AntigravityQuotaScopeGeminiText, 0, "", false,
	)

	// The model rate limit should be triggered.
	require.NotNil(t, outcome)
	require.True(t, outcome.Handled)
	require.NotNil(t, outcome.SwitchError)
	require.Equal(t, "gemini-3-pro-high", outcome.SwitchError.RateLimitedModel)
	require.Len(t, accountRepo.modelRateLimitCalls, 1)
	require.Equal(t, "gemini-3-pro-high", accountRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_NonModelRateLimit covers a 503 whose reason is not
// MODEL_CAPACITY_EXHAUSTED: the handler must return nil and record no model,
// scope, or account rate limit.
func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
	accountRepo := &stubAntigravityAccountRepo{}
	gateway := &AntigravityGatewayService{accountRepo: accountRepo}
	acc := &Account{ID: 4, Name: "acc-4", Platform: PlatformAntigravity}

	// 503 with an ordinary SERVICE_UNAVAILABLE reason and no model metadata.
	payload := []byte(`{
"error": {
"status": "UNAVAILABLE",
"message": "Service temporarily unavailable",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "SERVICE_UNAVAILABLE"}
]
}
}`)

	got := gateway.handleUpstreamError(
		context.Background(), "[test]", acc,
		http.StatusServiceUnavailable, http.Header{}, payload,
		AntigravityQuotaScopeGeminiText, 0, "", false,
	)

	// Nothing at all may be recorded for this response.
	require.Nil(t, got)
	require.Empty(t, accountRepo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
	require.Empty(t, accountRepo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
	require.Empty(t, accountRepo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
// TestHandleUpstreamError_503_EmptyBody 测试 503 空响应体(不处理)
func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo} svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity} account := &Account{ID: 5, Name: "acc-5", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("2s") // 503 + 空响应体 → 不做任何处理
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) body := []byte(`{}`)
require.Len(t, repo.rateCalls, 1) result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls) require.Empty(t, repo.scopeCalls)
call := repo.rateCalls[0] require.Empty(t, repo.rateCalls)
require.Equal(t, account.ID, call.accountID)
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
} }
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
...@@ -188,3 +321,751 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { ...@@ -188,3 +321,751 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
// buildGeminiRateLimitBody returns a rate-limit error payload whose first detail
// entry carries delay as metadata.quotaResetDelay. The delay is embedded with %q,
// so it appears as a double-quoted string in the output.
func buildGeminiRateLimitBody(delay string) []byte {
	return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
}
// TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp checks that a
// fractional quotaResetDelay ("0.1s") yields a reset timestamp one full second
// after the current Unix second.
func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
	// Stay clear of a Unix second boundary: only proceed while we are within the
	// first 800ms of the current second, otherwise the expectation could flake.
	for time.Now().Nanosecond() >= 800*1e6 {
		time.Sleep(5 * time.Millisecond)
	}

	wantUnix := time.Now().Unix() + 1

	resetAt := ParseGeminiRateLimitResetTime(buildGeminiRateLimitBody("0.1s"))
	require.NotNil(t, resetAt)
	require.Equal(t, wantUnix, *resetAt, "fractional seconds should be rounded up to the next second")
}
// TestParseAntigravitySmartRetryInfo is a table-driven test for
// parseAntigravitySmartRetryInfo. Each case feeds one upstream error payload and
// expects either a nil result or a specific retry delay and model name.
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
	cases := []struct {
		name      string
		payload   string
		wantDelay time.Duration
		wantModel string
		wantNil   bool
	}{
		{
			name: "valid complete response with RATE_LIMIT_EXCEEDED",
			payload: `{
"error": {
"code": 429,
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"domain": "cloudcode-pa.googleapis.com",
"metadata": {
"model": "claude-sonnet-4-5",
"quotaResetDelay": "201.506475ms"
},
"reason": "RATE_LIMIT_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "0.201506475s"
}
],
"message": "You have exhausted your capacity on this model.",
"status": "RESOURCE_EXHAUSTED"
}
}`,
			wantDelay: 201506475 * time.Nanosecond,
			wantModel: "claude-sonnet-4-5",
		},
		{
			name: "429 RESOURCE_EXHAUSTED without RATE_LIMIT_EXCEEDED - should return nil",
			payload: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"metadata": {"model": "claude-sonnet-4-5"},
"reason": "QUOTA_EXCEEDED"
},
{
"@type": "type.googleapis.com/google.rpc.RetryInfo",
"retryDelay": "3s"
}
]
}
}`,
			wantNil: true,
		},
		{
			name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
			payload: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`,
			wantDelay: 39 * time.Second,
			wantModel: "gemini-3-pro-high",
		},
		{
			name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
			payload: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "SERVICE_UNAVAILABLE"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
]
}
}`,
			wantNil: true,
		},
		{
			name: "wrong status - should return nil",
			payload: `{
"error": {
"code": 429,
"status": "INVALID_ARGUMENT",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			wantNil: true,
		},
		{
			name: "missing status - should return nil",
			payload: `{
"error": {
"code": 429,
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			wantNil: true,
		},
		{
			name: "milliseconds format is now supported",
			payload: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test-model"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "500ms"}
]
}
}`,
			wantDelay: 500 * time.Millisecond,
			wantModel: "test-model",
		},
		{
			name: "minutes format is supported",
			payload: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "4m50s"}
]
}
}`,
			wantDelay: 4*time.Minute + 50*time.Second,
			wantModel: "gemini-3-pro",
		},
		{
			name: "missing model name - should return nil",
			payload: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			wantNil: true,
		},
		{
			name:    "invalid JSON",
			payload: `not json`,
			wantNil: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := parseAntigravitySmartRetryInfo([]byte(tc.payload))

			switch {
			case tc.wantNil:
				// Rejected payloads must produce no retry info at all.
				if got != nil {
					t.Errorf("expected nil, got %+v", got)
				}
			case got == nil:
				t.Errorf("expected non-nil result")
			default:
				if got.RetryDelay != tc.wantDelay {
					t.Errorf("RetryDelay = %v, want %v", got.RetryDelay, tc.wantDelay)
				}
				if got.ModelName != tc.wantModel {
					t.Errorf("ModelName = %q, want %q", got.ModelName, tc.wantModel)
				}
			}
		})
	}
}
// TestShouldTriggerAntigravitySmartRetry is a table-driven test for
// shouldTriggerAntigravitySmartRetry. For each account type and payload it checks
// the retry / rate-limit decision, the minimum wait when a retry is requested,
// and the reported model name whenever either decision is positive.
func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
	var (
		oauthAcc      = &Account{Type: AccountTypeOAuth}
		setupTokenAcc = &Account{Type: AccountTypeSetupToken}
		apiKeyAcc     = &Account{Type: AccountTypeAPIKey}
	)

	cases := []struct {
		name          string
		account       *Account
		payload       string
		wantRetry     bool
		wantRateLimit bool
		minWait       time.Duration
		wantModel     string
	}{
		{
			name:    "OAuth account with short delay (< 7s) - smart retry",
			account: oauthAcc,
			payload: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
			wantRetry:     true,
			wantRateLimit: false,
			minWait:       1 * time.Second, // 0.5s is below 1s, so the 1s minimum wait is expected
			wantModel:     "claude-opus-4",
		},
		{
			name:    "SetupToken account with short delay - smart retry",
			account: setupTokenAcc,
			payload: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "3s"}
]
}
}`,
			wantRetry:     true,
			wantRateLimit: false,
			minWait:       3 * time.Second,
			wantModel:     "gemini-3-flash",
		},
		{
			name:    "OAuth account with long delay (>= 7s) - direct rate limit",
			account: oauthAcc,
			payload: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`,
			wantRetry:     false,
			wantRateLimit: true,
			wantModel:     "claude-sonnet-4-5",
		},
		{
			name:    "API Key account - should not trigger",
			account: apiKeyAcc,
			payload: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "test"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`,
			wantRetry:     false,
			wantRateLimit: false,
		},
		{
			name:    "OAuth account with exactly 7s delay - direct rate limit",
			account: oauthAcc,
			payload: `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`,
			wantRetry:     false,
			wantRateLimit: true,
			wantModel:     "gemini-pro",
		},
		{
			name:    "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - long delay",
			account: oauthAcc,
			payload: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`,
			wantRetry:     false,
			wantRateLimit: true,
			wantModel:     "gemini-3-pro-high",
		},
		{
			name:    "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
			account: oauthAcc,
			payload: `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-2.5-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}
],
"message": "No capacity available for model gemini-2.5-flash on the server"
}
}`,
			wantRetry:     false,
			wantRateLimit: true,
			wantModel:     "gemini-2.5-flash",
		},
		{
			name:    "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
			account: oauthAcc,
			payload: `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`,
			wantRetry:     false,
			wantRateLimit: true,
			wantModel:     "claude-sonnet-4-5",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotRetry, gotRateLimit, gotWait, gotModel := shouldTriggerAntigravitySmartRetry(tc.account, []byte(tc.payload))

			if gotRetry != tc.wantRetry {
				t.Errorf("shouldRetry = %v, want %v", gotRetry, tc.wantRetry)
			}
			if gotRateLimit != tc.wantRateLimit {
				t.Errorf("shouldRateLimit = %v, want %v", gotRateLimit, tc.wantRateLimit)
			}
			// The wait is only meaningful when a retry was actually requested.
			if gotRetry && gotWait < tc.minWait {
				t.Errorf("wait = %v, want >= %v", gotWait, tc.minWait)
			}
			// The model name is only checked when some action was decided.
			if gotRetry || gotRateLimit {
				if gotModel != tc.wantModel {
					t.Errorf("modelName = %q, want %q", gotModel, tc.wantModel)
				}
			}
		})
	}
}
// TestSetModelRateLimitByModelName_UsesOfficialModelID verifies that
// setModelRateLimitByModelName writes the rate limit under the model name it was
// given (the official model ID), and that an empty model name is rejected
// without touching the repository.
func TestSetModelRateLimitByModelName_UsesOfficialModelID(t *testing.T) {
	const accountID = 123

	cases := []struct {
		name    string
		model   string
		wantKey string
		wantOK  bool
	}{
		{
			name:    "claude-sonnet-4-5 should be stored as-is",
			model:   "claude-sonnet-4-5",
			wantKey: "claude-sonnet-4-5",
			wantOK:  true,
		},
		{
			name:    "gemini-3-pro-high should be stored as-is",
			model:   "gemini-3-pro-high",
			wantKey: "gemini-3-pro-high",
			wantOK:  true,
		},
		{
			name:    "gemini-3-flash should be stored as-is",
			model:   "gemini-3-flash",
			wantKey: "gemini-3-flash",
			wantOK:  true,
		},
		{
			name:    "empty model name should fail",
			model:   "",
			wantKey: "",
			wantOK:  false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			accountRepo := &stubAntigravityAccountRepo{}
			resetAt := time.Now().Add(30 * time.Second)

			// Last argument is afterSmartRetry.
			ok := setModelRateLimitByModelName(context.Background(), accountRepo, accountID, tc.model, "[test]", 429, resetAt, false)
			require.Equal(t, tc.wantOK, ok)

			if !tc.wantOK {
				require.Empty(t, accountRepo.modelRateLimitCalls)
				return
			}

			require.Len(t, accountRepo.modelRateLimitCalls, 1)
			recorded := accountRepo.modelRateLimitCalls[0]
			require.Equal(t, int64(123), recorded.accountID)
			// Key assertion: the stored key is the official model ID, not a scope.
			require.Equal(t, tc.wantKey, recorded.modelKey, "should store official model ID, not scope")
			require.WithinDuration(t, resetAt, recorded.resetAt, time.Second)
		})
	}
}
// TestSetModelRateLimitByModelName_NotConvertToScope verifies that the model name
// is stored verbatim and is not translated into a scope-style key.
func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
	const officialModelID = "claude-sonnet-4-5"

	accountRepo := &stubAntigravityAccountRepo{}
	resetAt := time.Now().Add(30 * time.Second)

	// Pass the official model ID; the last argument is afterSmartRetry.
	ok := setModelRateLimitByModelName(context.Background(), accountRepo, 456, officialModelID, "[test]", 429, resetAt, true)
	require.True(t, ok)
	require.Len(t, accountRepo.modelRateLimitCalls, 1)

	recorded := accountRepo.modelRateLimitCalls[0]
	// Key assertion: "claude-sonnet-4-5" must be stored, never "claude_sonnet".
	require.Equal(t, officialModelID, recorded.modelKey, "should NOT convert to scope like claude_sonnet")
	require.NotEqual(t, "claude_sonnet", recorded.modelKey, "should NOT be scope")
}
// TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold sets a model
// rate limit that resets about 2s from now and runs the retry loop under a 30ms
// context deadline. The loop is expected to wait on the pre-check, so the call
// ends with context.DeadlineExceeded and the upstream is never invoked.
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
	const model = "claude-sonnet-4-5"

	fakeUpstream := &recordingOKUpstream{}

	// RFC3339 has second precision, so keep the reset time safely in the future.
	resetAt := time.Now().Add(2 * time.Second).Format(time.RFC3339)
	acc := &Account{
		ID:          1,
		Name:        "acc-1",
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
		Extra: map[string]any{
			modelRateLimitsKey: map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}

	deadlineCtx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()

	got, err := antigravityRetryLoop(antigravityRetryLoopParams{
		ctx:             deadlineCtx,
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		requestedModel:  model,
		httpUpstream:    fakeUpstream,
		isStickySession: true,
		handleError:     noopHandleError,
	})

	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, got)
	require.Equal(t, 0, fakeUpstream.calls, "should not call upstream while waiting on pre-check")
}
// TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold sets a
// model rate limit that resets about 11s from now. The retry loop is expected to
// give up on the pre-check with an AntigravityAccountSwitchError describing the
// account and model, without ever invoking the upstream.
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
	const model = "claude-sonnet-4-5"

	fakeUpstream := &recordingOKUpstream{}

	resetAt := time.Now().Add(11 * time.Second).Format(time.RFC3339)
	acc := &Account{
		ID:          2,
		Name:        "acc-2",
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
		Extra: map[string]any{
			modelRateLimitsKey: map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}

	got, err := antigravityRetryLoop(antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		requestedModel:  model,
		httpUpstream:    fakeUpstream,
		isStickySession: true,
		handleError:     noopHandleError,
	})
	require.Nil(t, got)

	var switchErr *AntigravityAccountSwitchError
	require.ErrorAs(t, err, &switchErr)
	require.Equal(t, acc.ID, switchErr.OriginalAccountID)
	require.Equal(t, model, switchErr.RateLimitedModel)
	require.True(t, switchErr.IsStickySession)

	require.Equal(t, 0, fakeUpstream.calls, "should not call upstream when switching on pre-check")
}
// TestIsAntigravityAccountSwitchError checks that IsAntigravityAccountSwitchError
// recognises both direct and %w-wrapped AntigravityAccountSwitchError values and
// reports (nil, false) for nil and unrelated errors.
func TestIsAntigravityAccountSwitchError(t *testing.T) {
	direct := &AntigravityAccountSwitchError{
		OriginalAccountID: 123,
		RateLimitedModel:  "claude-sonnet-4-5",
		IsStickySession:   true,
	}
	wrapped := fmt.Errorf("wrapped: %w", &AntigravityAccountSwitchError{
		OriginalAccountID: 456,
		RateLimitedModel:  "gemini-3-flash",
		IsStickySession:   false,
	})

	cases := []struct {
		name      string
		err       error
		wantOK    bool
		wantID    int64
		wantModel string
	}{
		{name: "nil error", err: nil, wantOK: false},
		{name: "generic error", err: fmt.Errorf("some error"), wantOK: false},
		{name: "account switch error", err: direct, wantOK: true, wantID: 123, wantModel: "claude-sonnet-4-5"},
		{name: "wrapped account switch error", err: wrapped, wantOK: true, wantID: 456, wantModel: "gemini-3-flash"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, ok := IsAntigravityAccountSwitchError(tc.err)
			require.Equal(t, tc.wantOK, ok)

			if !tc.wantOK {
				require.Nil(t, got)
				return
			}

			require.NotNil(t, got)
			require.Equal(t, tc.wantID, got.OriginalAccountID)
			require.Equal(t, tc.wantModel, got.RateLimitedModel)
		})
	}
}
// TestAntigravityAccountSwitchError_Error checks that the error message mentions
// both the original account ID and the rate-limited model.
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
	switchErr := &AntigravityAccountSwitchError{
		OriginalAccountID: 789,
		RateLimitedModel:  "claude-opus-4-5",
		IsStickySession:   true,
	}

	text := switchErr.Error()

	for _, fragment := range []string{"789", "claude-opus-4-5"} {
		require.Contains(t, text, fragment)
	}
}
// stubSchedulerCache is a SchedulerCache test double. It embeds the interface so
// only the methods a test needs have to be implemented; SetAccount is overridden
// to record its arguments and return a configurable error.
type stubSchedulerCache struct {
	SchedulerCache

	// setAccountCalls collects every account passed to SetAccount, in call order.
	setAccountCalls []*Account
	// setAccountErr is returned from every SetAccount call.
	setAccountErr error
}

// SetAccount records the account and returns the configured error.
func (c *stubSchedulerCache) SetAccount(ctx context.Context, account *Account) error {
	c.setAccountCalls = append(c.setAccountCalls, account)
	return c.setAccountErr
}
// TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache verifies that
// updateAccountModelRateLimitInCache writes the model's rate-limit timestamps into
// account.Extra and pushes the account to the scheduler cache exactly once.
func TestUpdateAccountModelRateLimitInCache_UpdatesExtraAndCallsCache(t *testing.T) {
	const model = "claude-sonnet-4-5"

	fakeCache := &stubSchedulerCache{}
	gateway := &AntigravityGatewayService{
		schedulerSnapshot: &SchedulerSnapshotService{cache: fakeCache},
	}
	acc := &Account{
		ID:       100,
		Name:     "test-account",
		Platform: PlatformAntigravity,
	}

	gateway.updateAccountModelRateLimitInCache(context.Background(), acc, model, time.Now().Add(30*time.Second))

	// Extra must now hold a model_rate_limits entry for the model.
	require.NotNil(t, acc.Extra)
	allLimits, ok := acc.Extra["model_rate_limits"].(map[string]any)
	require.True(t, ok)
	entry, ok := allLimits[model].(map[string]any)
	require.True(t, ok)
	require.NotEmpty(t, entry["rate_limited_at"])
	require.NotEmpty(t, entry["rate_limit_reset_at"])

	// cache.SetAccount must have been called once, with this account.
	require.Len(t, fakeCache.setAccountCalls, 1)
	require.Equal(t, acc.ID, fakeCache.setAccountCalls[0].ID)
}
// TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot verifies that a nil
// schedulerSnapshot neither panics nor modifies the account.
func TestUpdateAccountModelRateLimitInCache_NilSchedulerSnapshot(t *testing.T) {
	gateway := &AntigravityGatewayService{schedulerSnapshot: nil}
	acc := &Account{ID: 1, Name: "test"}
	resetAt := time.Now().Add(30 * time.Second)

	// Must not panic.
	gateway.updateAccountModelRateLimitInCache(context.Background(), acc, "claude-sonnet-4-5", resetAt)

	// Extra must stay untouched (the function is expected to return early).
	require.Nil(t, acc.Extra)
}
// TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra verifies that
// adding a rate limit for one model keeps unrelated Extra keys and the existing
// rate-limit entries of other models.
func TestUpdateAccountModelRateLimitInCache_PreservesExistingExtra(t *testing.T) {
	cache := &stubSchedulerCache{}
	snapshotService := &SchedulerSnapshotService{cache: cache}
	svc := &AntigravityGatewayService{
		schedulerSnapshot: snapshotService,
	}
	account := &Account{
		ID:       200,
		Name:     "test-account",
		Platform: PlatformAntigravity,
		Extra: map[string]any{
			"existing_key": "existing_value",
			"model_rate_limits": map[string]any{
				"gemini-3-flash": map[string]any{
					"rate_limited_at":     "2024-01-01T00:00:00Z",
					"rate_limit_reset_at": "2024-01-01T00:05:00Z",
				},
			},
		},
	}

	svc.updateAccountModelRateLimitInCache(context.Background(), account, "claude-sonnet-4-5", time.Now().Add(30*time.Second))

	// Pre-existing data must survive the update.
	require.Equal(t, "existing_value", account.Extra["existing_key"])

	// Use a checked type assertion so an unexpected shape fails the test with a
	// clear message instead of panicking.
	limits, ok := account.Extra["model_rate_limits"].(map[string]any)
	require.True(t, ok, "model_rate_limits should remain a map[string]any")
	require.NotNil(t, limits["gemini-3-flash"])
	require.NotNil(t, limits["claude-sonnet-4-5"])
}
// TestSchedulerSnapshotService_UpdateAccountInCache exercises UpdateAccountInCache:
// it forwards to cache.SetAccount, tolerates a nil cache or nil account, and
// propagates the cache's error.
func TestSchedulerSnapshotService_UpdateAccountInCache(t *testing.T) {
	ctx := context.Background()

	t.Run("calls cache.SetAccount", func(t *testing.T) {
		fakeCache := &stubSchedulerCache{}
		snapshot := &SchedulerSnapshotService{cache: fakeCache}

		err := snapshot.UpdateAccountInCache(ctx, &Account{ID: 123, Name: "test"})

		require.NoError(t, err)
		require.Len(t, fakeCache.setAccountCalls, 1)
		require.Equal(t, int64(123), fakeCache.setAccountCalls[0].ID)
	})

	t.Run("returns nil when cache is nil", func(t *testing.T) {
		snapshot := &SchedulerSnapshotService{cache: nil}

		require.NoError(t, snapshot.UpdateAccountInCache(ctx, &Account{ID: 1}))
	})

	t.Run("returns nil when account is nil", func(t *testing.T) {
		fakeCache := &stubSchedulerCache{}
		snapshot := &SchedulerSnapshotService{cache: fakeCache}

		require.NoError(t, snapshot.UpdateAccountInCache(ctx, nil))
		require.Empty(t, fakeCache.setAccountCalls)
	})

	t.Run("propagates cache error", func(t *testing.T) {
		cacheErr := fmt.Errorf("cache error")
		snapshot := &SchedulerSnapshotService{cache: &stubSchedulerCache{setAccountErr: cacheErr}}

		err := snapshot.UpdateAccountInCache(ctx, &Account{ID: 1})

		require.ErrorIs(t, err, cacheErr)
	})
}
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// mockSmartRetryUpstream is a scripted upstream used by the handleSmartRetry
// tests. Each call consumes the next scripted (response, error) pair and records
// the request URL.
type mockSmartRetryUpstream struct {
	responses []*http.Response // scripted responses, consumed in call order
	errors    []error          // scripted errors, paired by index with responses
	callIdx   int              // number of calls made so far
	calls     []string         // request URLs, in call order
}

// Do records the request URL and returns the scripted pair for this call.
//
// A missing entry in errors is treated as a nil error rather than causing an
// index-out-of-range panic, so tests may omit errors when every call succeeds.
// Calls beyond the end of responses return (nil, nil).
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	idx := m.callIdx
	m.calls = append(m.calls, req.URL.String())
	m.callIdx++

	if idx >= len(m.responses) {
		return nil, nil
	}

	var err error
	if idx < len(m.errors) {
		err = m.errors[idx]
	}
	return m.responses[idx], err
}

// DoWithTLS ignores enableTLSFingerprint and delegates to Do, so both entry
// points share the same script and call log.
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return m.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestHandleSmartRetry_URLLevelRateLimit feeds handleSmartRetry a 429 without any
// model-level details while a second base URL is still available. The expected
// action is smartRetryActionContinueURL with no response, error, or switch error.
func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
	const currentURL = "https://ag-1.test"

	acc := &Account{
		ID:       1,
		Name:     "acc-1",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	payload := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acc,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	got := handleSmartRetry(loopParams, upstreamResp, payload, currentURL, 0, []string{currentURL, "https://ag-2.test"})

	require.NotNil(t, got)
	require.Equal(t, smartRetryActionContinueURL, got.action)
	require.Nil(t, got.resp)
	require.Nil(t, got.err)
	require.Nil(t, got.switchError)
}
// TestHandleSmartRetry_LongDelay_ReturnsSwitchError feeds handleSmartRetry a
// model-level 429 with a 15s retry delay. The expected outcome is a switch error
// for the account and model (no response, no error) plus one recorded
// model-level rate limit.
func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
	const (
		currentURL   = "https://ag-1.test"
		limitedModel = "claude-sonnet-4-5"
	)

	accountRepo := &stubAntigravityAccountRepo{}
	acc := &Account{
		ID:       1,
		Name:     "acc-1",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	// A 15s delay is at or above the 7s threshold this test targets, so a switch
	// error is expected instead of an in-place retry.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		accountRepo:     accountRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	got := handleSmartRetry(loopParams, upstreamResp, payload, currentURL, 0, []string{currentURL})

	require.NotNil(t, got)
	require.Equal(t, smartRetryActionBreakWithResp, got.action)
	require.Nil(t, got.resp, "should not return resp when switchError is set")
	require.Nil(t, got.err)
	require.NotNil(t, got.switchError, "should return switchError for long delay")
	require.Equal(t, acc.ID, got.switchError.OriginalAccountID)
	require.Equal(t, limitedModel, got.switchError.RateLimitedModel)
	require.True(t, got.switchError.IsStickySession)

	// The model-level rate limit must have been recorded.
	require.Len(t, accountRepo.modelRateLimitCalls, 1)
	require.Equal(t, limitedModel, accountRepo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_SmartRetrySuccess feeds handleSmartRetry a
// model-level 429 with a 0.5s retry delay and an upstream scripted to succeed on
// the next call. The expected outcome is the 200 response after exactly one retry
// call, with no error and no switch error.
func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
	const currentURL = "https://ag-1.test"

	okResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
	}
	fakeUpstream := &mockSmartRetryUpstream{
		responses: []*http.Response{okResp},
		errors:    []error{nil},
	}

	acc := &Account{
		ID:       1,
		Name:     "acc-1",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	// A 0.5s delay is below the 7s threshold this test targets, so an in-place
	// smart retry is expected.
	payload := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:          context.Background(),
		prefix:       "[test]",
		account:      acc,
		accessToken:  "token",
		action:       "generateContent",
		body:         []byte(`{"input":"test"}`),
		httpUpstream: fakeUpstream,
		handleError:  noopHandleError,
	}

	got := handleSmartRetry(loopParams, upstreamResp, payload, currentURL, 0, []string{currentURL})

	require.NotNil(t, got)
	require.Equal(t, smartRetryActionBreakWithResp, got.action)
	require.NotNil(t, got.resp, "should return successful response")
	require.Equal(t, http.StatusOK, got.resp.StatusCode)
	require.Nil(t, got.err)
	require.Nil(t, got.switchError, "should not return switchError on success")
	require.Len(t, fakeUpstream.calls, 1, "should have made one retry call")
}
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError feeds
// handleSmartRetry a model-level 429 with a 0.1s retry delay while the upstream
// keeps answering with the same 429. After three retry calls the expected outcome
// is a switch error for the account and model plus one recorded model-level
// rate limit.
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
	const (
		currentURL   = "https://ag-1.test"
		limitedModel = "gemini-3-flash"
	)

	// Every retry still gets a 429. Three responses are scripted because the smart
	// retry is expected to make at most three attempts.
	failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
	newFailResp := func() *http.Response {
		return &http.Response{
			StatusCode: http.StatusTooManyRequests,
			Header:     http.Header{},
			Body:       io.NopCloser(strings.NewReader(failRespBody)),
		}
	}
	fakeUpstream := &mockSmartRetryUpstream{
		responses: []*http.Response{newFailResp(), newFailResp(), newFailResp()},
		errors:    []error{nil, nil, nil},
	}

	accountRepo := &stubAntigravityAccountRepo{}
	acc := &Account{
		ID:       2,
		Name:     "acc-2",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	// The initial response carries the same payload: a 0.1s delay, below the 7s
	// threshold this test targets, so in-place smart retries are expected first.
	payload := []byte(failRespBody)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}

	// Error handler that never reports a handled rate limit.
	noopHandleError := func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acc,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		httpUpstream:    fakeUpstream,
		accountRepo:     accountRepo,
		isStickySession: false,
		handleError:     noopHandleError,
	}

	got := handleSmartRetry(loopParams, upstreamResp, payload, currentURL, 0, []string{currentURL})

	require.NotNil(t, got)
	require.Equal(t, smartRetryActionBreakWithResp, got.action)
	require.Nil(t, got.resp, "should not return resp when switchError is set")
	require.Nil(t, got.err)
	require.NotNil(t, got.switchError, "should return switchError after smart retry failed")
	require.Equal(t, acc.ID, got.switchError.OriginalAccountID)
	require.Equal(t, limitedModel, got.switchError.RateLimitedModel)
	require.False(t, got.switchError.IsStickySession)

	// The model-level rate limit must have been recorded.
	require.Len(t, accountRepo.modelRateLimitCalls, 1)
	require.Equal(t, limitedModel, accountRepo.modelRateLimitCalls[0].modelKey)
	require.Len(t, fakeUpstream.calls, 3, "should have made three retry calls (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
result := handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NonOAuthAccount_ContinuesDefaultLogic checks that a non-OAuth
// account is handed back to the default logic instead of being smart-retried.
func TestHandleSmartRetry_NonOAuthAccount_ContinuesDefaultLogic(t *testing.T) {
	// Even for a model rate-limit response, a non-OAuth account should take the default path.
	respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(respBody)),
	}

	apiKeyAccount := &Account{
		ID:       4,
		Name:     "acc-4",
		Type:     AccountTypeAPIKey, // not an OAuth account
		Platform: PlatformAntigravity,
	}

	// The error handler is a no-op stub that always returns nil.
	noopHandleError := func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ AntigravityQuotaScope, _ int64, _ string, _ bool) *handleModelRateLimitResult {
		return nil
	}
	params := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     apiKeyAccount,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	result := handleSmartRetry(params, upstreamResp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, result)
	require.Equal(t, smartRetryActionContinue, result.action, "non-OAuth account should continue default logic")
	require.Nil(t, result.resp)
	require.Nil(t, result.err)
	require.Nil(t, result.switchError)
}
// TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic checks that a response
// which is not a model rate limit is handed back to the default logic.
func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) {
	// A 429 that carries neither RATE_LIMIT_EXCEEDED nor MODEL_CAPACITY_EXHAUSTED.
	respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "5s"}
],
"message": "Quota exceeded"
}
}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(respBody)),
	}

	acct := &Account{
		ID:       5,
		Name:     "acc-5",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	// The error handler is a no-op stub that always returns nil.
	noopHandleError := func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ AntigravityQuotaScope, _ int64, _ string, _ bool) *handleModelRateLimitResult {
		return nil
	}
	params := antigravityRetryLoopParams{
		ctx:         context.Background(),
		prefix:      "[test]",
		account:     acct,
		accessToken: "token",
		action:      "generateContent",
		body:        []byte(`{"input":"test"}`),
		handleError: noopHandleError,
	}

	result := handleSmartRetry(params, upstreamResp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, result)
	require.Equal(t, smartRetryActionContinue, result.action, "non-model rate limit should continue default logic")
	require.Nil(t, result.resp)
	require.Nil(t, result.err)
	require.Nil(t, result.switchError)
}
// TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError 测试刚好等于阈值时返回 switchError
func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 6,
Name: "acc-6",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 刚好 7s = 7s 阈值,应该返回 switchError
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "7s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
result := handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp)
require.NotNil(t, result.switchError, "exactly at threshold should return switchError")
require.Equal(t, "gemini-pro", result.switchError.RateLimitedModel)
}
// TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates checks that a
// switchError surfaces from antigravityRetryLoop as an *AntigravityAccountSwitchError.
func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing.T) {
	// Simulate a 429 response that carries a long retry delay.
	respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "30s"}
]
}
}`)
	throttledResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(respBody)),
	}
	mockUpstream := &mockSmartRetryUpstream{
		responses: []*http.Response{throttledResp},
		errors:    []error{nil},
	}

	acct := &Account{
		ID:          7,
		Name:        "acc-7",
		Type:        AccountTypeOAuth,
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
	}
	rateLimitRepo := &stubAntigravityAccountRepo{}

	// The error handler is a no-op stub that always returns nil.
	noopHandleError := func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ AntigravityQuotaScope, _ int64, _ string, _ bool) *handleModelRateLimitResult {
		return nil
	}
	loopParams := antigravityRetryLoopParams{
		ctx:             context.Background(),
		prefix:          "[test]",
		account:         acct,
		accessToken:     "token",
		action:          "generateContent",
		body:            []byte(`{"input":"test"}`),
		httpUpstream:    mockUpstream,
		accountRepo:     rateLimitRepo,
		isStickySession: true,
		handleError:     noopHandleError,
	}

	result, err := antigravityRetryLoop(loopParams)

	require.Nil(t, result, "should not return result when switchError")
	require.NotNil(t, err, "should return error")

	var switchErr *AntigravityAccountSwitchError
	require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
	require.Equal(t, acct.ID, switchErr.OriginalAccountID)
	require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
	require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry checks that smart retry keeps going
// after a network-error-like attempt and returns the later successful response.
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
	// The first attempt simulates a network error; the second one succeeds.
	okResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
	}
	mockUpstream := &mockSmartRetryUpstream{
		// First entry is nil to simulate the network error.
		responses: []*http.Response{nil, okResp},
		// The mock returns no error; the nil response is what triggers the path.
		errors: []error{nil, nil},
	}

	// 0.1s is below the 7s threshold, so smart retry should be triggered.
	respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
	throttledResp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(respBody)),
	}

	acct := &Account{
		ID:       8,
		Name:     "acc-8",
		Type:     AccountTypeOAuth,
		Platform: PlatformAntigravity,
	}

	// The error handler is a no-op stub that always returns nil.
	noopHandleError := func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ AntigravityQuotaScope, _ int64, _ string, _ bool) *handleModelRateLimitResult {
		return nil
	}
	params := antigravityRetryLoopParams{
		ctx:          context.Background(),
		prefix:       "[test]",
		account:      acct,
		accessToken:  "token",
		action:       "generateContent",
		body:         []byte(`{"input":"test"}`),
		httpUpstream: mockUpstream,
		handleError:  noopHandleError,
	}

	result := handleSmartRetry(params, throttledResp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})

	require.NotNil(t, result)
	require.Equal(t, smartRetryActionBreakWithResp, result.action)
	require.NotNil(t, result.resp, "should return successful response after network error recovery")
	require.Equal(t, http.StatusOK, result.resp.StatusCode)
	require.Nil(t, result.switchError, "should not return switchError on success")
	require.Len(t, mockUpstream.calls, 2, "should have made two retry calls")
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 9,
Name: "acc-9",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 + RATE_LIMIT_EXCEEDED + 无 retryDelay → 使用默认 1 分钟限流
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}
],
"message": "You have exhausted your capacity on this model."
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
result := handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.NotNil(t, result.switchError, "should return switchError for no retryDelay")
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
//go:build unit
package service
import (
"testing"
)
// TestApplyThinkingModelSuffix is a table-driven test for applyThinkingModelSuffix.
// Per the table, only "claude-sonnet-4-5" with thinking enabled is expected to gain the
// "-thinking" suffix; every other combination is expected to come back unchanged.
func TestApplyThinkingModelSuffix(t *testing.T) {
	cases := []struct {
		name            string
		mappedModel     string
		thinkingEnabled bool
		expected        string
	}{
		// Thinking disabled: the model is left as-is.
		{name: "thinking disabled - claude-sonnet-4-5 unchanged", mappedModel: "claude-sonnet-4-5", thinkingEnabled: false, expected: "claude-sonnet-4-5"},
		{name: "thinking disabled - other model unchanged", mappedModel: "claude-opus-4-6-thinking", thinkingEnabled: false, expected: "claude-opus-4-6-thinking"},
		// Thinking enabled + claude-sonnet-4-5: the suffix is appended automatically.
		{name: "thinking enabled - claude-sonnet-4-5 becomes thinking version", mappedModel: "claude-sonnet-4-5", thinkingEnabled: true, expected: "claude-sonnet-4-5-thinking"},
		// Thinking enabled + any other model: the model is left as-is.
		{name: "thinking enabled - claude-sonnet-4-5-thinking unchanged", mappedModel: "claude-sonnet-4-5-thinking", thinkingEnabled: true, expected: "claude-sonnet-4-5-thinking"},
		{name: "thinking enabled - claude-opus-4-6-thinking unchanged", mappedModel: "claude-opus-4-6-thinking", thinkingEnabled: true, expected: "claude-opus-4-6-thinking"},
		{name: "thinking enabled - gemini model unchanged", mappedModel: "gemini-3-flash", thinkingEnabled: true, expected: "gemini-3-flash"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := applyThinkingModelSuffix(tc.mappedModel, tc.thinkingEnabled)
			if got == tc.expected {
				return
			}
			t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
				tc.mappedModel, tc.thinkingEnabled, got, tc.expected)
		})
	}
}
...@@ -35,6 +35,7 @@ type ConcurrencyCache interface { ...@@ -35,6 +35,7 @@ type ConcurrencyCache interface {
// 批量负载查询(只读) // 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
// 清理过期槽位(后台任务) // 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
...@@ -77,6 +78,11 @@ type AccountWithConcurrency struct { ...@@ -77,6 +78,11 @@ type AccountWithConcurrency struct {
MaxConcurrency int MaxConcurrency int
} }
// UserWithConcurrency pairs a user ID with that user's maximum concurrency.
// It is the per-user input element of GetUsersLoadBatch.
type UserWithConcurrency struct {
	ID             int64
	MaxConcurrency int
}
type AccountLoadInfo struct { type AccountLoadInfo struct {
AccountID int64 AccountID int64
CurrentConcurrency int CurrentConcurrency int
...@@ -84,6 +90,13 @@ type AccountLoadInfo struct { ...@@ -84,6 +90,13 @@ type AccountLoadInfo struct {
LoadRate int // 0-100+ (percent) LoadRate int // 0-100+ (percent)
} }
// UserLoadInfo is the per-user result element of GetUsersLoadBatch.
type UserLoadInfo struct {
	UserID             int64
	CurrentConcurrency int
	WaitingCount       int
	LoadRate           int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account. // AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout. // If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes. // Returns a release function that MUST be called when the request completes.
...@@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts ...@@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts
return s.cache.GetAccountsLoadBatch(ctx, accounts) return s.cache.GetAccountsLoadBatch(ctx, accounts)
} }
// GetUsersLoadBatch returns load info for multiple users, keyed by int64.
// When the service has no cache configured it returns an empty, non-nil map and a
// nil error; otherwise the call is delegated to the cache.
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
	if s.cache != nil {
		return s.cache.GetUsersLoadBatch(ctx, users)
	}
	return map[int64]*UserLoadInfo{}, nil
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task). // CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil { if s.cache == nil {
......
package service
import "github.com/gin-gonic/gin"
// errorPassthroughServiceContextKey is the gin context key under which the
// *ErrorPassthroughService is stored by BindErrorPassthroughService.
const errorPassthroughServiceContextKey = "error_passthrough_service"
// BindErrorPassthroughService stores the error passthrough service on the request
// context so the service layer can reuse its rules in non-failover scenarios.
// It does nothing when either c or svc is nil.
func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) {
	if c != nil && svc != nil {
		c.Set(errorPassthroughServiceContextKey, svc)
	}
}
// getBoundErrorPassthroughService returns the *ErrorPassthroughService stored on the
// request context. It returns nil when c is nil, when nothing is stored under the key,
// or when the stored value is not a *ErrorPassthroughService.
func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService {
	if c == nil {
		return nil
	}
	if bound, found := c.Get(errorPassthroughServiceContextKey); found {
		if svc, isService := bound.(*ErrorPassthroughService); isService {
			return svc
		}
	}
	return nil
}
// applyErrorPassthroughRule rewrites an error response according to the matching
// passthrough rule; when nothing matches it returns the default response parameters.
//
// Returns:
//   - status: the upstream status, or the rule's ResponseCode when the rule does not
//     pass the code through and a ResponseCode is set; defaultStatus when unmatched.
//   - errType: "upstream_error" when a rule matched, otherwise defaultErrType.
//   - errMsg: the message extracted from responseBody, or the rule's CustomMessage when
//     the rule does not pass the body through and a CustomMessage is set; defaultErrMsg
//     when unmatched.
//   - matched: whether a rule matched.
func applyErrorPassthroughRule(
	c *gin.Context,
	platform string,
	upstreamStatus int,
	responseBody []byte,
	defaultStatus int,
	defaultErrType string,
	defaultErrMsg string,
) (status int, errType string, errMsg string, matched bool) {
	// No service bound to this request: keep the defaults.
	svc := getBoundErrorPassthroughService(c)
	if svc == nil {
		return defaultStatus, defaultErrType, defaultErrMsg, false
	}

	// No rule matched: keep the defaults.
	rule := svc.MatchRule(platform, upstreamStatus, responseBody)
	if rule == nil {
		return defaultStatus, defaultErrType, defaultErrMsg, false
	}

	status = upstreamStatus
	if overrideCode := !rule.PassthroughCode && rule.ResponseCode != nil; overrideCode {
		status = *rule.ResponseCode
	}

	errMsg = ExtractUpstreamErrorMessage(responseBody)
	if overrideBody := !rule.PassthroughBody && rule.CustomMessage != nil; overrideBody {
		errMsg = *rule.CustomMessage
	}

	// Stay consistent with the existing failover path: a matched rule always reports upstream_error.
	return status, "upstream_error", errMsg, true
}
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestApplyErrorPassthroughRule_NoBoundService checks that the defaults are returned
// unchanged when no ErrorPassthroughService is bound to the gin context.
func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())

	upstreamBody := []byte(`{"error":{"message":"invalid schema"}}`)
	status, errType, errMsg, matched := applyErrorPassthroughRule(
		ginCtx,
		PlatformAnthropic,
		http.StatusUnprocessableEntity,
		upstreamBody,
		http.StatusBadGateway,
		"upstream_error",
		"Upstream request failed",
	)

	assert.False(t, matched)
	assert.Equal(t, http.StatusBadGateway, status)
	assert.Equal(t, "upstream_error", errType)
	assert.Equal(t, "Upstream request failed", errMsg)
}
// TestGatewayHandleErrorResponse_NoRuleKeepsDefault checks that GatewayService writes
// the default 502 upstream_error payload when no passthrough service is bound.
func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
	gateway := &GatewayService{}

	_, err := gateway.handleErrorResponse(context.Background(), upstreamResp, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadGateway, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestOpenAIHandleErrorResponse_NoRuleKeepsDefault checks that OpenAIGatewayService
// writes the default 502 upstream_error payload when no passthrough service is bound.
func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	gateway := &OpenAIGatewayService{}

	_, err := gateway.handleErrorResponse(context.Background(), upstreamResp, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadGateway, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault checks that, with no passthrough
// service bound, writeGeminiMappedError answers an upstream 422 with a 400
// invalid_request_error carrying the default message.
func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	upstreamBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
	acct := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey}
	compat := &GeminiMessagesCompatService{}

	err := compat.writeGeminiMappedError(ginCtx, acct, http.StatusUnprocessableEntity, "req-2", upstreamBody)
	require.Error(t, err)
	assert.Equal(t, http.StatusBadRequest, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "invalid_request_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestGatewayHandleErrorResponse_AppliesRuleFor422 checks that GatewayService applies a
// bound passthrough rule to an upstream 422: the rule's response code and custom
// message are written instead of the defaults.
func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Bind a service whose local cache holds a single rule matching 422 + "invalid schema".
	passthrough := &ErrorPassthroughService{}
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
	gateway := &GatewayService{}

	_, err := gateway.handleErrorResponse(context.Background(), upstreamResp, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "上游请求失败", errObj["message"])
}
// TestOpenAIHandleErrorResponse_AppliesRuleFor422 checks that OpenAIGatewayService
// applies a bound passthrough rule to an upstream 422: the rule's response code and
// custom message are written instead of the defaults.
func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Bind a service whose local cache holds a single rule matching 422 + "invalid schema".
	passthrough := &ErrorPassthroughService{}
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewReader(upstreamBody)),
		Header:     http.Header{},
	}
	acct := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	gateway := &OpenAIGatewayService{}

	_, err := gateway.handleErrorResponse(context.Background(), upstreamResp, ginCtx, acct)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "OpenAI上游失败", errObj["message"])
}
// TestGeminiWriteGeminiMappedError_AppliesRuleFor422 checks that writeGeminiMappedError
// applies a bound passthrough rule to an upstream 422: the rule's response code and
// custom message are written instead of the defaults.
func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Bind a service whose local cache holds a single rule matching 422 + "invalid schema".
	passthrough := &ErrorPassthroughService{}
	rule := newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")
	passthrough.setLocalCache([]*model.ErrorPassthroughRule{rule})
	BindErrorPassthroughService(ginCtx, passthrough)

	upstreamBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
	acct := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey}
	compat := &GeminiMessagesCompatService{}

	err := compat.writeGeminiMappedError(ginCtx, acct, http.StatusUnprocessableEntity, "req-1", upstreamBody)
	require.Error(t, err)
	assert.Equal(t, http.StatusTeapot, recorder.Code)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	errObj, isObject := decoded["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Gemini上游失败", errObj["message"])
}
// newNonFailoverPassthroughRule builds an enabled test rule that matches the given
// status code and keyword (MatchModeAll) and that overrides both the response code and
// the message instead of passing the upstream values through.
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
	rule := &model.ErrorPassthroughRule{
		ID:        1,
		Name:      "non-failover-rule",
		Enabled:   true,
		Priority:  1,
		MatchMode: model.MatchModeAll,
	}

	// Match criteria.
	rule.ErrorCodes = []int{statusCode}
	rule.Keywords = []string{keyword}

	// Response overrides: neither the upstream code nor the upstream body is passed through.
	rule.PassthroughCode = false
	rule.ResponseCode = &respCode
	rule.PassthroughBody = false
	rule.CustomMessage = &customMessage

	return rule
}
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
...@@ -61,11 +60,8 @@ func NewErrorPassthroughService( ...@@ -61,11 +60,8 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存 // 启动时加载规则到本地缓存
ctx := context.Background() ctx := context.Background()
if err := svc.reloadRulesFromDB(ctx); err != nil { if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
}
} }
// 订阅缓存更新通知 // 订阅缓存更新通知
...@@ -102,9 +98,7 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP ...@@ -102,9 +98,7 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP
} }
// 刷新缓存 // 刷新缓存
refreshCtx, cancel := s.newCacheRefreshContext() s.invalidateAndNotify(ctx)
defer cancel()
s.invalidateAndNotify(refreshCtx)
return created, nil return created, nil
} }
...@@ -121,9 +115,7 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP ...@@ -121,9 +115,7 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP
} }
// 刷新缓存 // 刷新缓存
refreshCtx, cancel := s.newCacheRefreshContext() s.invalidateAndNotify(ctx)
defer cancel()
s.invalidateAndNotify(refreshCtx)
return updated, nil return updated, nil
} }
...@@ -135,9 +127,7 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { ...@@ -135,9 +127,7 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
} }
// 刷新缓存 // 刷新缓存
refreshCtx, cancel := s.newCacheRefreshContext() s.invalidateAndNotify(ctx)
defer cancel()
s.invalidateAndNotify(refreshCtx)
return nil return nil
} }
...@@ -199,12 +189,7 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { ...@@ -199,12 +189,7 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
} }
} }
return s.reloadRulesFromDB(ctx) // 从数据库加载(repo.List 已按 priority 排序)
}
// 从数据库加载(repo.List 已按 priority 排序)
// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。
func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
rules, err := s.repo.List(ctx) rules, err := s.repo.List(ctx)
if err != nil { if err != nil {
return err return err
...@@ -237,32 +222,11 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR ...@@ -237,32 +222,11 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR
s.localCacheMu.Unlock() s.localCacheMu.Unlock()
} }
// clearLocalCache empties the local cache so that stale rules are not matched after a
// failed refresh. The reset happens while holding localCacheMu.
func (s *ErrorPassthroughService) clearLocalCache() {
	s.localCacheMu.Lock()
	defer s.localCacheMu.Unlock()

	s.localCache = nil
}
// newCacheRefreshContext creates a standalone context for cache synchronisation on the
// write path, so the refresh is not affected by cancellation of the request context.
// The context is derived from context.Background with a 3-second timeout; the caller
// is responsible for invoking the returned cancel function.
func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) {
	const refreshTimeout = 3 * time.Second
	return context.WithTimeout(context.Background(), refreshTimeout)
}
// invalidateAndNotify 使缓存失效并通知其他实例 // invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 先失效缓存,避免后续刷新读到陈旧规则。
if s.cache != nil {
if err := s.cache.Invalidate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
}
}
// 刷新本地缓存 // 刷新本地缓存
if err := s.reloadRulesFromDB(ctx); err != nil { if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s.clearLocalCache()
} }
// 通知其他实例 // 通知其他实例
......
...@@ -4,7 +4,6 @@ package service ...@@ -4,7 +4,6 @@ package service
import ( import (
"context" "context"
"errors"
"strings" "strings"
"testing" "testing"
...@@ -15,81 +14,14 @@ import ( ...@@ -15,81 +14,14 @@ import (
// mockErrorPassthroughRepo 用于测试的 mock repository // mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct { type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule rules []*model.ErrorPassthroughRule
listErr error
getErr error
createErr error
updateErr error
deleteErr error
}
// mockErrorPassthroughCache is an in-memory cache test double that stores a copy of
// the rules and counts how often each method is called.
type mockErrorPassthroughCache struct {
	rules            []*model.ErrorPassthroughRule
	hasData          bool
	getCalled        int
	setCalled        int
	invalidateCalled int
	notifyCalled     int
}

// newMockErrorPassthroughCache builds a mock seeded with a copy of rules; hasData
// controls whether Get reports a hit.
func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache {
	mock := &mockErrorPassthroughCache{hasData: hasData}
	mock.rules = cloneRules(rules)
	return mock
}

// Get counts the call and returns a copy of the stored rules, or (nil, false) when the
// mock holds no data.
func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
	m.getCalled++
	if m.hasData {
		return cloneRules(m.rules), true
	}
	return nil, false
}

// Set counts the call, stores a copy of rules and marks the mock as holding data.
func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
	m.setCalled++
	m.rules, m.hasData = cloneRules(rules), true
	return nil
}

// Invalidate counts the call and drops the stored rules.
func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error {
	m.invalidateCalled++
	m.rules, m.hasData = nil, false
	return nil
}

// NotifyUpdate only counts the call.
func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error {
	m.notifyCalled++
	return nil
}

// SubscribeUpdates is a no-op: unit tests do not need subscription behaviour.
func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
}
func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule {
if rules == nil {
return nil
}
out := make([]*model.ErrorPassthroughRule, len(rules))
copy(out, rules)
return out
} }
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
if m.listErr != nil {
return nil, m.listErr
}
return m.rules, nil return m.rules, nil
} }
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
if m.getErr != nil {
return nil, m.getErr
}
for _, r := range m.rules { for _, r := range m.rules {
if r.ID == id { if r.ID == id {
return r, nil return r, nil
...@@ -99,18 +31,12 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode ...@@ -99,18 +31,12 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode
} }
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if m.createErr != nil {
return nil, m.createErr
}
rule.ID = int64(len(m.rules) + 1) rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule) m.rules = append(m.rules, rule)
return rule, nil return rule, nil
} }
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if m.updateErr != nil {
return nil, m.updateErr
}
for i, r := range m.rules { for i, r := range m.rules {
if r.ID == rule.ID { if r.ID == rule.ID {
m.rules[i] = rule m.rules[i] = rule
...@@ -121,9 +47,6 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error ...@@ -121,9 +47,6 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error
} }
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
if m.deleteErr != nil {
return m.deleteErr
}
for i, r := range m.rules { for i, r := range m.rules {
if r.ID == id { if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...) m.rules = append(m.rules[:i], m.rules[i+1:]...)
...@@ -827,158 +750,6 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { ...@@ -827,158 +750,6 @@ func TestErrorPassthroughRule_Validate(t *testing.T) {
} }
} }
// =============================================================================
// 测试写路径缓存刷新(Create/Update/Delete)
// =============================================================================
// TestCreate_ForceRefreshCacheAfterWrite checks that a rule created through
// the service becomes matchable immediately, even though both the mock cache
// and the service-local cache were seeded with a stale rule, and asserts the
// mock cache call counters observed after the write.
func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
// Stale rule (ID 99) lives only in the caches; the repo starts empty.
staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
// Create a rule with the same keyword but a different custom message.
newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败")
created, err := svc.Create(ctx, newRule)
require.NoError(t, err)
require.NotNil(t, created)
// The match must resolve to the freshly created rule, not the stale one.
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
matched := svc.MatchRule("anthropic", 503, body)
require.NotNil(t, matched)
assert.Equal(t, created.ID, matched.ID)
if assert.NotNil(t, matched.CustomMessage) {
assert.Equal(t, "上游请求失败", *matched.CustomMessage)
}
// Expected cache interaction: no Get, one Invalidate, one Set, one NotifyUpdate.
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
// TestUpdate_ForceRefreshCacheAfterWrite checks that after Update the old
// keyword no longer matches while the new keyword does, and asserts the mock
// cache call counters observed after the write.
func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
// Repo, mock cache and local cache all start with the original rule (ID 1).
originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule})
// Replace the rule with one carrying a new keyword and message.
updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息")
_, err := svc.Update(ctx, updatedRule)
require.NoError(t, err)
// The old keyword must stop matching.
oldBody := []byte(`{"message":"old keyword"}`)
oldMatched := svc.MatchRule("anthropic", 503, oldBody)
assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中")
// The new keyword must match and expose the new custom message.
newBody := []byte(`{"message":"new keyword"}`)
newMatched := svc.MatchRule("anthropic", 503, newBody)
require.NotNil(t, newMatched)
if assert.NotNil(t, newMatched.CustomMessage) {
assert.Equal(t, "新消息", *newMatched.CustomMessage)
}
// Expected cache interaction: no Get, one Invalidate, one Set, one NotifyUpdate.
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
// TestDelete_ForceRefreshCacheAfterWrite checks that a deleted rule stops
// matching immediately and asserts the mock cache call counters observed
// after the write.
func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
// Repo, mock cache and local cache all start with the rule to be deleted.
rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{rule})
err := svc.Delete(ctx, 1)
require.NoError(t, err)
// A body containing the deleted rule's keyword must no longer match.
body := []byte(`{"message":"to be deleted"}`)
matched := svc.MatchRule("anthropic", 503, body)
assert.Nil(t, matched, "删除后规则不应再命中")
// Expected cache interaction: no Get, one Invalidate, one Set, one NotifyUpdate.
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
// TestNewService_StartupReloadFromDBToHealStaleCache checks that a service
// built with NewErrorPassthroughService matches against the repo's rules
// rather than a stale rule sitting in the cache, and that the cache is
// written once without being read.
func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) {
// The cache holds a stale rule (ID 99); the repo holds the latest rule (ID 1).
staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息")
latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := NewErrorPassthroughService(repo, cache)
// The repo's rule must match.
matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`))
require.NotNil(t, matchedFresh)
assert.Equal(t, int64(1), matchedFresh.ID)
// The stale cached rule must not match.
matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`))
assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存")
// Expected cache interaction at startup: no Get, exactly one Set.
assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存")
}
// TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule checks that when the
// repo's List call fails after a successful Update, the previously cached
// enabled rule stops matching and the service-local cache is left nil.
func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) {
ctx := context.Background()
staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息")
// listErr makes the repo's List fail while Update itself still succeeds.
repo := &mockErrorPassthroughRepo{
rules: []*model.ErrorPassthroughRule{staleRule},
listErr: errors.New("db list failed"),
}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
// Update a copy of the rule with Enabled switched off.
disabledRule := *staleRule
disabledRule.Enabled = false
_, err := svc.Update(ctx, &disabledRule)
require.NoError(t, err)
// The stale enabled rule must not match any more.
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
matched := svc.MatchRule("anthropic", 503, body)
assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则")
// Inspect the local cache under its read lock; it must have been cleared.
svc.localCacheMu.RLock()
assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中")
svc.localCacheMu.RUnlock()
}
// newPassthroughRuleForWritePathTest returns an enabled rule fixture for the
// write-path cache tests. The rule matches status 503 with a single keyword,
// uses MatchModeAll, overrides the response code with 503 and replaces the
// body with customMsg (neither code nor body is passed through).
func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule {
	code := 503
	return &model.ErrorPassthroughRule{
		ID:              id,
		Name:            "write-path-cache-refresh",
		Enabled:         true,
		Priority:        1,
		ErrorCodes:      []int{503},
		Keywords:        []string{keyword},
		MatchMode:       model.MatchModeAll,
		PassthroughCode: false,
		ResponseCode:    &code,
		PassthroughBody: false,
		CustomMessage:   &customMsg,
	}
}
// Helper functions // Helper functions
func testIntPtr(i int) *int { return &i } func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s } func testStrPtr(s string) *string { return &s }
//go:build unit
package service
import (
"context"
"testing"
)
// TestIsForceCacheBilling verifies that IsForceCacheBilling reports true only
// when the context stores the boolean value true under
// ForceCacheBillingContextKey; a missing value, a false value and a value of
// the wrong type must all yield false.
func TestIsForceCacheBilling(t *testing.T) {
	base := context.Background()

	cases := []struct {
		name     string
		ctx      context.Context
		expected bool
	}{
		{"context without force cache billing", base, false},
		{"context with force cache billing set to true", context.WithValue(base, ForceCacheBillingContextKey, true), true},
		{"context with force cache billing set to false", context.WithValue(base, ForceCacheBillingContextKey, false), false},
		{"context with wrong type value", context.WithValue(base, ForceCacheBillingContextKey, "true"), false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := IsForceCacheBilling(tc.ctx); got != tc.expected {
				t.Errorf("IsForceCacheBilling() = %v, want %v", got, tc.expected)
			}
		})
	}
}
// TestWithForceCacheBilling verifies that WithForceCacheBilling marks the
// derived context while leaving the parent context unmarked.
func TestWithForceCacheBilling(t *testing.T) {
	parent := context.Background()

	// The parent starts without the marker.
	if IsForceCacheBilling(parent) {
		t.Error("original context should not have force cache billing")
	}

	// The derived context must carry the marker.
	child := WithForceCacheBilling(parent)
	if marked := IsForceCacheBilling(child); !marked {
		t.Error("new context should have force cache billing")
	}

	// Deriving a child must not affect the parent.
	if IsForceCacheBilling(parent) {
		t.Error("original context should still not have force cache billing")
	}
}
// TestForceCacheBilling_TokenConversion exercises the token conversion rule:
// when force cache billing is on and input tokens are positive, they are
// added to the cache-read tokens and the input tokens are zeroed.
//
// NOTE(review): the conversion is re-implemented inline below rather than
// invoked through RecordUsage, so this test does not detect drift in the
// production code path — confirm the two stay in sync.
func TestForceCacheBilling_TokenConversion(t *testing.T) {
	type tokenCase struct {
		name                    string
		forceCacheBilling       bool
		inputTokens             int
		cacheReadInputTokens    int
		expectedInputTokens     int
		expectedCacheReadTokens int
	}

	cases := []tokenCase{
		{
			name:                    "force cache billing converts input to cache_read",
			forceCacheBilling:       true,
			inputTokens:             1000,
			cacheReadInputTokens:    500,
			expectedInputTokens:     0,
			expectedCacheReadTokens: 1500, // 500 + 1000
		},
		{
			name:                    "no force cache billing keeps tokens unchanged",
			forceCacheBilling:       false,
			inputTokens:             1000,
			cacheReadInputTokens:    500,
			expectedInputTokens:     1000,
			expectedCacheReadTokens: 500,
		},
		{
			name:                    "force cache billing with zero input tokens does nothing",
			forceCacheBilling:       true,
			inputTokens:             0,
			cacheReadInputTokens:    500,
			expectedInputTokens:     0,
			expectedCacheReadTokens: 500,
		},
		{
			name:                    "force cache billing with zero cache_read tokens",
			forceCacheBilling:       true,
			inputTokens:             1000,
			cacheReadInputTokens:    0,
			expectedInputTokens:     0,
			expectedCacheReadTokens: 1000,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Mirror of the ForceCacheBilling handling attributed to RecordUsage.
			usage := ClaudeUsage{
				InputTokens:          tc.inputTokens,
				CacheReadInputTokens: tc.cacheReadInputTokens,
			}
			if tc.forceCacheBilling && usage.InputTokens > 0 {
				usage.CacheReadInputTokens += usage.InputTokens
				usage.InputTokens = 0
			}

			if got, want := usage.InputTokens, tc.expectedInputTokens; got != want {
				t.Errorf("InputTokens = %d, want %d", got, want)
			}
			if got, want := usage.CacheReadInputTokens, tc.expectedCacheReadTokens; got != want {
				t.Errorf("CacheReadInputTokens = %d, want %d", got, want)
			}
		})
	}
}
...@@ -216,6 +216,22 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context ...@@ -216,6 +216,22 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil return nil
} }
// IncrModelCallCount is a stub: it records nothing and always returns (0, nil).
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
// GetModelLoadBatch is a stub: it always returns a nil map and a nil error,
// so callers see no load information for any account.
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
// FindGeminiSession is a stub: it always reports that no session was found.
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
// SaveGeminiSession is a stub: it stores nothing and always returns nil.
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct { type mockGroupRepoForGateway struct {
groups map[int64]*Group groups map[int64]*Group
getByIDCalls int getByIDCalls int
...@@ -332,7 +348,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing ...@@ -332,7 +348,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing
cfg: testConfig(), cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAntigravity)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, acc) require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID) require.Equal(t, int64(2), acc.ID)
...@@ -670,7 +686,7 @@ func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *tes ...@@ -670,7 +686,7 @@ func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *tes
cfg: testConfig(), cfg: testConfig(),
} }
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-sonnet-4-5", nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, acc) require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID) require.Equal(t, int64(2), acc.ID)
...@@ -1014,11 +1030,17 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { ...@@ -1014,11 +1030,17 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
expected bool expected bool
}{ }{
{ {
name: "Antigravity平台-支持claude模型", name: "Antigravity平台-支持默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity}, account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022", model: "claude-sonnet-4-5",
expected: true, expected: true,
}, },
{
name: "Antigravity平台-不支持非默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{ {
name: "Antigravity平台-支持gemini模型", name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity}, account: &Account{Platform: PlatformAntigravity},
...@@ -1115,7 +1137,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1115,7 +1137,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(), cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, acc) require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)") require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
...@@ -1123,7 +1145,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1123,7 +1145,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) { t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) {
groupID := int64(30) groupID := int64(30)
requestedModel := "claude-3-5-sonnet-20241022" requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
...@@ -1168,7 +1190,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1168,7 +1190,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由粘性命中", func(t *testing.T) { t.Run("混合调度-路由粘性命中", func(t *testing.T) {
groupID := int64(31) groupID := int64(31)
requestedModel := "claude-3-5-sonnet-20241022" requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
...@@ -1320,7 +1342,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1320,7 +1342,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
Schedulable: true, Schedulable: true,
Extra: map[string]any{ Extra: map[string]any{
"model_rate_limits": map[string]any{ "model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{ "claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": resetAt.Format(time.RFC3339), "rate_limit_reset_at": resetAt.Format(time.RFC3339),
}, },
}, },
...@@ -1465,7 +1487,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1465,7 +1487,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(), cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, acc) require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户") require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
...@@ -1597,7 +1619,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -1597,7 +1619,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(), cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, acc) require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID) require.Equal(t, int64(1), acc.ID)
...@@ -1870,6 +1892,19 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a ...@@ -1870,6 +1892,19 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil return nil
} }
// GetUsersLoadBatch returns an idle load entry for every requested user: each
// entry carries the user's ID and zero concurrency, waiting count and load
// rate. It never fails.
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
	loads := make(map[int64]*UserLoadInfo, len(users))
	for i := range users {
		id := users[i].ID
		// All remaining fields keep their zero value, i.e. no load at all.
		loads[id] = &UserLoadInfo{UserID: id}
	}
	return loads, nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection // TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -2747,7 +2782,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -2747,7 +2782,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
Concurrency: 5, Concurrency: 5,
Extra: map[string]any{ Extra: map[string]any{
"model_rate_limits": map[string]any{ "model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{ "claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": now.Format(time.RFC3339), "rate_limit_reset_at": now.Format(time.RFC3339),
}, },
}, },
......
...@@ -4,6 +4,8 @@ import ( ...@@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
// ParsedRequest 保存网关请求的预解析结果 // ParsedRequest 保存网关请求的预解析结果
...@@ -19,13 +21,14 @@ import ( ...@@ -19,13 +21,14 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层 // 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct { type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发) Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称 Model string // 请求的模型名称
Stream bool // 是否为流式请求 Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和) MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容 System any // system 字段内容
Messages []any // messages 数组 Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
} }
// ParseGatewayRequest 解析网关请求体并返回结构化结果 // ParseGatewayRequest 解析网关请求体并返回结构化结果
...@@ -69,6 +72,13 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { ...@@ -69,6 +72,13 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.Messages = messages parsed.Messages = messages
} }
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
return parsed, nil return parsed, nil
} }
...@@ -466,7 +476,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { ...@@ -466,7 +476,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// only keep thinking blocks with valid signatures // only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" { if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string) signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" { if signature != "" && signature != antigravity.DummyThoughtSignature {
newContent = append(newContent, block) newContent = append(newContent, block)
continue continue
} }
......
...@@ -17,6 +17,15 @@ func TestParseGatewayRequest(t *testing.T) { ...@@ -17,6 +17,15 @@ func TestParseGatewayRequest(t *testing.T) {
require.True(t, parsed.HasSystem) require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System) require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1) require.Len(t, parsed.Messages, 1)
require.False(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
} }
func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_SystemNull(t *testing.T) {
......
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
...@@ -49,6 +50,29 @@ const ( ...@@ -49,6 +50,29 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// forceCacheBillingKeyType is the unexported type of the context key used for
// the force-cache-billing marker (see ForceCacheBillingContextKey).
// Per the original note, the marker is used when a sticky session switches
// accounts, so that input_tokens are billed as cache_read_input_tokens.
type forceCacheBillingKeyType struct{}
// accountWithLoad pairs an account with its load information for use in
// load-aware scheduling.
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
// ForceCacheBillingContextKey is the context key under which the
// force-cache-billing marker (a bool) is stored.
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
// IsForceCacheBilling reports whether force cache billing is enabled on ctx.
// It returns true only when ctx stores the boolean value true under
// ForceCacheBillingContextKey; a missing value or a value of another type
// yields false.
func IsForceCacheBilling(ctx context.Context) bool {
	enabled, ok := ctx.Value(ForceCacheBillingContextKey).(bool)
	return ok && enabled
}
// WithForceCacheBilling returns a context derived from ctx that carries the
// force-cache-billing marker (true under ForceCacheBillingContextKey).
// The parent context is not modified.
func WithForceCacheBilling(ctx context.Context) context.Context {
return context.WithValue(ctx, ForceCacheBillingContextKey, true)
}
func (s *GatewayService) debugModelRoutingEnabled() bool { func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on" return v == "1" || v == "true" || v == "yes" || v == "on"
...@@ -250,6 +274,13 @@ var allowedHeaders = map[string]bool{ ...@@ -250,6 +274,13 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。 // GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
// //
// ModelLoadInfo holds per-model load information for an account and is used
// for Antigravity scheduling.
type ModelLoadInfo struct {
CallCount int64 // Call count in the current minute
LastUsedAt time.Time // Last scheduling time (zero value means never scheduled)
}
// GatewayCache defines cache operations for gateway service. // GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities. // Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface { type GatewayCache interface {
...@@ -265,6 +296,24 @@ type GatewayCache interface { ...@@ -265,6 +296,24 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable // Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
} }
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil // derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...@@ -275,16 +324,23 @@ func derefGroupID(groupID *int64) int64 { ...@@ -275,16 +324,23 @@ func derefGroupID(groupID *int64) int64 {
return *groupID return *groupID
} }
// stickySessionRateLimitThreshold is the rate-limit duration above which a
// sticky session is cleared.
// When an account's remaining rate-limit time exceeds this threshold, the
// sticky session is cleared so requests can switch to another account.
// Below the threshold the sticky session is kept, waiting for the short
// rate limit to end.
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。 // 这确保后续请求不会继续使用不可用的账号。
// //
// shouldClearStickySession checks if an account is in an unschedulable state // shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared. // and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false, // Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period. // within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// This ensures subsequent requests won't continue using unavailable accounts. // This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool { func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil { if account == nil {
return false return false
} }
...@@ -294,6 +350,10 @@ func shouldClearStickySession(account *Account) bool { ...@@ -294,6 +350,10 @@ func shouldClearStickySession(account *Account) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true return true
} }
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
return true
}
return false return false
} }
...@@ -336,8 +396,9 @@ type ForwardResult struct { ...@@ -336,8 +396,9 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct { type UpstreamFailoverError struct {
StatusCode int StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配 ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
} }
func (e *UpstreamFailoverError) Error() string { func (e *UpstreamFailoverError) Error() string {
...@@ -470,6 +531,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID ...@@ -470,6 +531,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
return accountID, nil return accountID, nil
} }
// FindGeminiSession looks up a Gemini session by content digest chain (the
// fallback matching path) by delegating to the gateway cache, which is
// documented as returning the longest matching session (uuid, accountID).
// It reports not-found without touching the cache when digestChain is empty
// or no cache is configured.
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
	if s.cache != nil && digestChain != "" {
		return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
	}
	return "", 0, false
}
// SaveGeminiSession stores a Gemini session binding by delegating to the
// gateway cache. It is a no-op returning nil when digestChain is empty or no
// cache is configured.
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
	if s.cache != nil && digestChain != "" {
		return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
	}
	return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil { if parsed == nil {
return "" return ""
...@@ -968,6 +1046,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -968,6 +1046,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 1. 过滤出路由列表中可调度的账号 // 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID
for _, routingAccountID := range routingAccountIDs { for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) { if isExcluded(routingAccountID) {
filteredExcluded++ filteredExcluded++
...@@ -986,12 +1065,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -986,12 +1065,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredPlatform++ filteredPlatform++
continue continue
} }
if !account.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) {
filteredModelScope++ filteredModelMapping++
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) { if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
filteredModelMapping++ filteredModelScope++
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue continue
} }
// 窗口费用检查(非粘性会话路径) // 窗口费用检查(非粘性会话路径)
...@@ -1006,6 +1086,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1006,6 +1086,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
if len(modelScopeSkippedIDs) > 0 {
log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
derefGroupID(groupID), requestedModel, modelScopeSkippedIDs)
}
} }
if len(routingCandidates) > 0 { if len(routingCandidates) > 0 {
...@@ -1017,8 +1101,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1017,8 +1101,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() && if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) && stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -1075,10 +1159,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1075,10 +1159,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序 // 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad var routingAvailable []accountWithLoad
for _, acc := range routingCandidates { for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID] loadInfo := routingLoadMap[acc.ID]
...@@ -1169,14 +1249,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1169,14 +1249,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok { if ok {
// 检查账户是否需要清理粘性会话绑定 // 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup // Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -1234,10 +1314,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1234,10 +1314,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue continue
} }
// 窗口费用检查(非粘性会话路径) // 窗口费用检查(非粘性会话路径)
...@@ -1265,10 +1345,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1265,10 +1345,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil return result, nil
} }
} else { } else {
type accountWithLoad struct { // Antigravity 平台:获取模型负载信息
account *Account var modelLoadMap map[int64]*ModelLoadInfo
loadInfo *AccountLoadInfo isAntigravity := platform == PlatformAntigravity
}
var available []accountWithLoad var available []accountWithLoad
for _, acc := range candidates { for _, acc := range candidates {
loadInfo := loadMap[acc.ID] loadInfo := loadMap[acc.ID]
...@@ -1283,47 +1363,108 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1283,47 +1363,108 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
if len(available) > 0 { // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
sort.SliceStable(available, func(i, j int) bool { if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
a, b := available[i], available[j] modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
if a.account.Priority != b.account.Priority { modelToAccountIDs := make(map[string][]int64)
return a.account.Priority < b.account.Priority for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
} }
if a.loadInfo.LoadRate != b.loadInfo.LoadRate { modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
return a.loadInfo.LoadRate < b.loadInfo.LoadRate }
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
} }
switch { for id, info := range batch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: modelLoadMap[id] = info
return true }
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: }
return false if len(modelLoadMap) == 0 {
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: modelLoadMap = nil
if preferOAuth && a.account.Type != b.account.Type { }
return a.account.Type == AccountTypeOAuth }
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
} }
})
for _, item := range available { // 移除已尝试的账号,重新选择
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue } else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if sessionHash != "" && s.cache != nil { }
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
} }
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
available = newAvailable
} }
} }
} }
...@@ -1740,6 +1881,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in ...@@ -1740,6 +1881,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
// filterByMinPriority returns the accounts whose Priority equals the smallest
// Priority value present in the input, preserving their relative order.
// An empty input is returned as-is.
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
	if len(accounts) == 0 {
		return accounts
	}

	// Pass 1: find the smallest priority value.
	lowest := accounts[0].account.Priority
	for i := 1; i < len(accounts); i++ {
		if p := accounts[i].account.Priority; p < lowest {
			lowest = p
		}
	}

	// Pass 2: keep every account sitting at that priority.
	kept := make([]accountWithLoad, 0, len(accounts))
	for i := range accounts {
		if accounts[i].account.Priority != lowest {
			continue
		}
		kept = append(kept, accounts[i])
	}
	return kept
}
// filterByMinLoadRate returns the accounts whose loadInfo.LoadRate equals the
// smallest LoadRate present in the input, preserving their relative order.
// An empty input is returned as-is. Every entry's loadInfo is dereferenced,
// so callers must supply non-nil load info.
func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
	if len(accounts) == 0 {
		return accounts
	}

	// Pass 1: find the smallest load rate.
	lowest := accounts[0].loadInfo.LoadRate
	for i := 1; i < len(accounts); i++ {
		if rate := accounts[i].loadInfo.LoadRate; rate < lowest {
			lowest = rate
		}
	}

	// Pass 2: keep every account running at that load rate.
	kept := make([]accountWithLoad, 0, len(accounts))
	for i := range accounts {
		if accounts[i].loadInfo.LoadRate != lowest {
			continue
		}
		kept = append(kept, accounts[i])
	}
	return kept
}
// selectByLRU picks the least-recently-used account from the given set.
//
// An account whose LastUsedAt is nil counts as older than any account with a
// timestamp. When several accounts tie for the oldest LastUsedAt, the tie is
// narrowed to OAuth accounts if preferOAuth is set and at least one exists,
// and one of the remaining accounts is chosen at random.
//
// It returns nil for an empty input; otherwise the result points into the
// accounts slice.
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
	switch len(accounts) {
	case 0:
		return nil
	case 1:
		return &accounts[0]
	}

	// Find the oldest timestamp, stopping as soon as a never-used account
	// shows up (nil beats every real timestamp).
	var oldest *time.Time
	neverUsed := false
	for i := range accounts {
		used := accounts[i].account.LastUsedAt
		if used == nil {
			neverUsed = true
			break
		}
		if oldest == nil || used.Before(*oldest) {
			oldest = used
		}
	}

	// Collect the indexes of all accounts tied for "oldest".
	var tied []int
	for i := range accounts {
		used := accounts[i].account.LastUsedAt
		var matches bool
		if neverUsed {
			matches = used == nil
		} else {
			matches = used != nil && used.Equal(*oldest)
		}
		if matches {
			tied = append(tied, i)
		}
	}

	// A unique winner needs no tie-breaking.
	if len(tied) == 1 {
		return &accounts[tied[0]]
	}

	// Prefer OAuth accounts among the tied ones when requested.
	if preferOAuth {
		var oauthOnly []int
		for _, idx := range tied {
			if accounts[idx].account.Type == AccountTypeOAuth {
				oauthOnly = append(oauthOnly, idx)
			}
		}
		if len(oauthOnly) > 0 {
			tied = oauthOnly
		}
	}

	// Break the remaining tie randomly.
	pick := tied[mathrand.Intn(len(tied))]
	return &accounts[pick]
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool { sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j] a, b := accounts[i], accounts[j]
...@@ -1762,6 +2003,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -1762,6 +2003,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
}) })
} }
// selectByCallCount picks the candidate with the fewest recorded calls for the
// requested model (used by the Antigravity scheduling path).
//
// Accounts with no entry in modelLoadMap, or with CallCount == 0, are treated
// as having the average call count of the accounts that do have calls, so a
// fresh account is not hammered on cold start. Ties on the minimum count are
// narrowed to OAuth accounts when preferOAuth is set and at least one exists,
// then resolved randomly.
//
// It returns nil for an empty input and falls back to selectByLRU when
// modelLoadMap is nil. The result points into the accounts slice.
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
	switch {
	case len(accounts) == 0:
		return nil
	case len(accounts) == 1:
		return &accounts[0]
	case modelLoadMap == nil:
		// No load information at all: fall back to LRU selection.
		return selectByLRU(accounts, preferOAuth)
	}

	// Average call count over accounts that already have calls; used as the
	// virtual count for new accounts.
	var sum int64
	var seen int
	for i := range accounts {
		info := modelLoadMap[accounts[i].account.ID]
		if info == nil || info.CallCount <= 0 {
			continue
		}
		sum += info.CallCount
		seen++
	}
	var average int64
	if seen > 0 {
		average = sum / int64(seen)
	}

	// Effective call count per candidate, index-aligned with accounts.
	effective := make([]int64, len(accounts))
	for i := range accounts {
		acct := accounts[i].account
		if acct == nil {
			continue // a missing account keeps an effective count of 0
		}
		if info := modelLoadMap[acct.ID]; info != nil && info.CallCount != 0 {
			effective[i] = info.CallCount
		} else {
			effective[i] = average // new account: use the average
		}
	}

	// Smallest effective call count.
	lowest := effective[0]
	for _, count := range effective[1:] {
		if count < lowest {
			lowest = count
		}
	}

	// Indexes of every candidate at that minimum.
	var tied []int
	for i, count := range effective {
		if count == lowest {
			tied = append(tied, i)
		}
	}

	// A unique winner needs no tie-breaking.
	if len(tied) == 1 {
		return &accounts[tied[0]]
	}

	// Prefer OAuth accounts among the tied ones when requested.
	if preferOAuth {
		var oauthOnly []int
		for _, idx := range tied {
			if accounts[idx].account.Type == AccountTypeOAuth {
				oauthOnly = append(oauthOnly, idx)
			}
		}
		if len(oauthOnly) > 0 {
			tied = oauthOnly
		}
	}

	// Break the remaining tie randomly.
	return &accounts[tied[mathrand.Intn(len(tied))]]
}
// sortCandidatesForFallback 根据配置选择排序策略 // sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机) // mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
...@@ -1843,11 +2165,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1843,11 +2165,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil { if err == nil {
clearSticky := shouldClearStickySession(account) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
} }
...@@ -1894,10 +2216,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1894,10 +2216,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() { if !acc.IsSchedulable() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -1946,11 +2268,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1946,11 +2268,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil { if err == nil {
clearSticky := shouldClearStickySession(account) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
} }
...@@ -1986,10 +2308,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1986,10 +2308,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() { if !acc.IsSchedulable() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -2056,11 +2378,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2056,11 +2378,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil { if err == nil {
clearSticky := shouldClearStickySession(account) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
...@@ -2109,10 +2431,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2109,10 +2431,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -2161,11 +2483,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2161,11 +2483,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID) account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil { if err == nil {
clearSticky := shouldClearStickySession(account) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
...@@ -2203,10 +2525,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2203,10 +2525,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue continue
} }
if selected == nil { if selected == nil {
...@@ -2250,11 +2572,44 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2250,11 +2572,44 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
return selected, nil return selected, nil
} }
// isModelSupportedByAccount 根据账户平台检查模型支持 // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { // 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查 // Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel) if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
return false
}
// 先用默认映射获取基础模型名,再应用 thinking 后缀
defaultMapped, exists := domain.DefaultAntigravityModelMapping[requestedModel]
if !exists || defaultMapped == "" {
return false
}
finalModel := defaultMapped
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
finalModel = applyThinkingModelSuffix(finalModel, enabled)
}
// 使用最终模型名检查 model_mapping 支持
return account.IsModelSupported(finalModel)
}
return s.isModelSupportedByAccount(account, requestedModel)
}
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 应使用 isModelSupportedByAccountWithContext
// 这里作为兼容保留,使用原始模型名检查
if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
return false
}
return account.IsModelSupported(requestedModel)
} }
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
...@@ -2269,10 +2624,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo ...@@ -2269,10 +2624,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
} }
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 // IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 // 只有在默认映射(DefaultAntigravityModelMapping)中配置的模型才被支持
func IsAntigravityModelSupported(requestedModel string) bool { func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") || // 检查是否在默认映射的 key 中
strings.HasPrefix(requestedModel, "gemini-") _, exists := domain.DefaultAntigravityModelMapping[requestedModel]
return exists
} }
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
...@@ -3563,34 +3919,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -3563,34 +3919,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
) )
} }
// 非 failover 错误也支持错误透传规则匹配。
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息) // 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string var errType, errMsg string
var statusCode int var statusCode int
...@@ -3722,33 +4050,6 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ...@@ -3722,33 +4050,6 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
) )
} }
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
respBody,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed after retries",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 返回统一的重试耗尽错误响应 // 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{ c.JSON(http.StatusBadGateway, gin.H{
"type": "error", "type": "error",
...@@ -4162,14 +4463,15 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo ...@@ -4162,14 +4463,15 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
} }
// APIKeyQuotaUpdater defines the interface for updating API Key quota // APIKeyQuotaUpdater defines the interface for updating API Key quota
...@@ -4185,6 +4487,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4185,6 +4487,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
...@@ -4345,6 +4656,7 @@ type RecordUsageLongContextInput struct { ...@@ -4345,6 +4656,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000) LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService *APIKeyService // API Key 配额服务(可选) APIKeyService *APIKeyService // API Key 配额服务(可选)
} }
...@@ -4356,6 +4668,15 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4356,6 +4668,15 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// TestGatewayService_isModelSupportedByAccount_AntigravityModelMapping checks
// that an Antigravity account's model_mapping is treated as a whitelist whose
// keys may contain wildcards.
func TestGatewayService_isModelSupportedByAccount_AntigravityModelMapping(t *testing.T) {
	gateway := &GatewayService{}

	// model_mapping doubles as the whitelist (wildcard matching).
	acct := &Account{
		Platform: PlatformAntigravity,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"claude-*":   "claude-sonnet-4-5",
				"gemini-3-*": "gemini-3-flash",
			},
		},
	}

	cases := []struct {
		model string
		want  bool
	}{
		// Expected to match the claude-* wildcard.
		{model: "claude-sonnet-4-5", want: true},
		{model: "claude-haiku-4-5", want: true},
		{model: "claude-opus-4-6", want: true},
		// Expected to match the gemini-3-* wildcard.
		{model: "gemini-3-flash", want: true},
		{model: "gemini-3-pro-high", want: true},
		// gemini-2.5-* has no entry in model_mapping.
		{model: "gemini-2.5-flash", want: false},
		{model: "gemini-2.5-pro", want: false},
		// Models of other platforms are not supported.
		{model: "gpt-4", want: false},
		// An empty model name is allowed.
		{model: "", want: true},
	}

	for _, tc := range cases {
		supported := gateway.isModelSupportedByAccount(acct, tc.model)
		if tc.want {
			require.True(t, supported)
		} else {
			require.False(t, supported)
		}
	}
}
// TestGatewayService_isModelSupportedByAccount_AntigravityNoMapping checks the
// behaviour of an Antigravity account that has no model_mapping configured.
// Per the original author's note, the default mapping
// (domain.DefaultAntigravityModelMapping) is then used, and only models listed
// there are expected to be supported.
func TestGatewayService_isModelSupportedByAccount_AntigravityNoMapping(t *testing.T) {
	gateway := &GatewayService{}
	acct := &Account{
		Platform:    PlatformAntigravity,
		Credentials: map[string]any{},
	}

	cases := []struct {
		model string
		want  bool
	}{
		// Models expected to be present in the default mapping.
		{model: "claude-sonnet-4-5", want: true},
		{model: "gemini-3-flash", want: true},
		{model: "gemini-2.5-pro", want: true},
		{model: "claude-haiku-4-5", want: true},
		// Models expected to be absent from the default mapping.
		{model: "claude-3-5-sonnet-20241022", want: false},
		{model: "claude-unknown-model", want: false},
		// Names without a claude-/gemini- prefix remain unsupported.
		{model: "gpt-4", want: false},
	}

	for _, tc := range cases {
		supported := gateway.isModelSupportedByAccount(acct, tc.model)
		if tc.want {
			require.True(t, supported)
		} else {
			require.False(t, supported)
		}
	}
}
// TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode covers
// the model-support check when the thinking flag is carried in the context.
// The scenarios expect the check to be made against the final model name
// (including a "-thinking" suffix when thinking is enabled) rather than the
// name the client requested.
func TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode(t *testing.T) {
	const (
		sonnet         = "claude-sonnet-4-5"
		sonnetThinking = "claude-sonnet-4-5-thinking"
		opus           = "claude-opus-4-6"
		opusThinking   = "claude-opus-4-6-thinking"
	)

	// identity builds a model_mapping in which every listed model maps to itself.
	identity := func(models ...string) map[string]any {
		mapping := make(map[string]any, len(models))
		for _, m := range models {
			mapping[m] = m
		}
		return mapping
	}

	cases := []struct {
		name     string
		mapping  map[string]any
		model    string
		thinking bool
		want     bool
	}{
		{
			// Scenario 1: only the thinking variant is mapped and thinking is on.
			// Expected final name is claude-sonnet-4-5-thinking, which is mapped.
			name:     "thinking_enabled_matches_thinking_model",
			mapping:  identity(sonnetThinking),
			model:    sonnet,
			thinking: true,
			want:     true,
		},
		{
			// Scenario 2: only the thinking variant is mapped and thinking is off.
			// Expected final name is claude-sonnet-4-5, which is not mapped.
			name:     "thinking_disabled_no_match_thinking_only_mapping",
			mapping:  identity(sonnetThinking),
			model:    sonnet,
			thinking: false,
			want:     false,
		},
		{
			// Scenario 3: only the non-thinking variant is mapped and thinking is on.
			// Expected final name is claude-sonnet-4-5-thinking, which is not mapped.
			name:     "thinking_enabled_no_match_non_thinking_mapping",
			mapping:  identity(sonnet),
			model:    sonnet,
			thinking: true,
			want:     false,
		},
		{
			// Scenario 4: both variants are mapped; thinking on should hit the
			// thinking variant.
			name:     "both_models_thinking_enabled_matches_thinking",
			mapping:  identity(sonnet, sonnetThinking),
			model:    sonnet,
			thinking: true,
			want:     true,
		},
		{
			// Scenario 5: both variants are mapped; thinking off should hit the
			// non-thinking variant.
			name:     "both_models_thinking_disabled_matches_non_thinking",
			mapping:  identity(sonnet, sonnetThinking),
			model:    sonnet,
			thinking: false,
			want:     true,
		},
		{
			// Scenario 6: the claude-* wildcard should cover thinking and
			// non-thinking names alike.
			name:     "wildcard_matches_thinking",
			mapping:  map[string]any{"claude-*": sonnet},
			model:    sonnet,
			thinking: true,
			want:     true, // claude-sonnet-4-5-thinking matches claude-*
		},
		{
			// Scenario 7: thinking handling for models other than sonnet-4-5.
			name:     "opus_thinking_unchanged",
			mapping:  identity(opusThinking),
			model:    opus,
			thinking: true,
			want:     true, // claude-opus-4-6 is expected to resolve to claude-opus-4-6-thinking
		},
	}

	gateway := &GatewayService{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, tc.thinking)
			acct := &Account{
				Platform: PlatformAntigravity,
				Credentials: map[string]any{
					"model_mapping": tc.mapping,
				},
			}

			got := gateway.isModelSupportedByAccountWithContext(ctx, acct, tc.model)
			require.Equal(t, tc.want, got,
				"isModelSupportedByAccountWithContext(ctx[thinking=%v], account, %q) = %v, want %v",
				tc.thinking, tc.model, got, tc.want)
		})
	}
}
...@@ -200,7 +200,7 @@ func (s *GeminiMessagesCompatService) tryStickySessionHit( ...@@ -200,7 +200,7 @@ func (s *GeminiMessagesCompatService) tryStickySessionHit(
// 检查账号是否需要清理粘性会话 // 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared // Check if sticky session should be cleared
if shouldClearStickySession(account) { if shouldClearStickySession(account, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil return nil
} }
...@@ -230,7 +230,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( ...@@ -230,7 +230,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
) bool { ) bool {
// 检查模型调度能力 // 检查模型调度能力
// Check model scheduling capability // Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) { if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false return false
} }
...@@ -1498,28 +1498,6 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc ...@@ -1498,28 +1498,6 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
} }
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformGemini,
upstreamStatus,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus)
}
return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int var statusCode int
var errType, errMsg string var errType, errMsg string
...@@ -2658,7 +2636,9 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 { ...@@ -2658,7 +2636,9 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
if meta, ok := dm["metadata"].(map[string]any); ok { if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok { if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil { if dur, err := time.ParseDuration(v); err == nil {
ts := time.Now().Unix() + int64(dur.Seconds()) // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
// which can affect scheduling decisions around thresholds (like 10s).
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts return &ts
} }
} }
......
...@@ -265,6 +265,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, ...@@ -265,6 +265,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil return nil
} }
// IncrModelCallCount is a no-op stub for tests: it ignores its arguments and
// always returns a count of 0 with a nil error.
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
// GetModelLoadBatch is a no-op stub for tests: it ignores its arguments and
// always returns a nil map with a nil error.
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
// FindGeminiSession is a no-op stub for tests: it always reports that no
// session was found (empty uuid, account ID 0, found=false).
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
// SaveGeminiSession is a no-op stub for tests: it stores nothing and always
// returns a nil error.
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -880,7 +896,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { ...@@ -880,7 +896,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
{ {
name: "Antigravity平台-支持claude模型", name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity}, account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022", model: "claude-sonnet-4-5",
expected: true, expected: true,
}, },
{ {
......
package service
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Constants for the Gemini session-ID fallback.
const (
// geminiSessionTTLSeconds is the TTL of a cached Gemini session, in seconds (5 minutes).
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix is the Redis key prefix for Gemini sessions.
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL returns the Gemini session cache TTL as a time.Duration,
// derived from geminiSessionTTLSeconds.
func GeminiSessionTTL() time.Duration {
	ttl := time.Duration(geminiSessionTTLSeconds) * time.Second
	return ttl
}
// shortHash returns the XXHash64 digest of data rendered in base 36.
// The result has variable length: a uint64 in base 36 needs at most 13
// characters (versus 16 for hex), and small digests produce shorter strings.
// The output uses only [0-9a-z], so it never contains the "-" or ":"
// separators used when building digest chains.
func shortHash(data []byte) string {
	return strconv.FormatUint(xxhash.Sum64(data), 36)
}
// BuildGeminiDigestChain derives a digest chain from a Gemini request.
//
// Format: s:<hash>-u:<hash>-m:<hash>-u:<hash>-...
// where "s" is the system instruction, "u" a user turn and "m" a model turn.
// Each <hash> is shortHash of the JSON encoding of that entry's Parts.
//
// A nil request yields "". The "s:" segment is emitted only when a system
// instruction with at least one part is present. Any content whose Role is
// not "model" is labelled "u". JSON marshalling errors are ignored, in which
// case the hash is computed over whatever bytes Marshal returned.
func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
	if req == nil {
		return ""
	}

	segments := make([]string, 0, len(req.Contents)+1)

	// System instruction comes first, when there is one.
	if sys := req.SystemInstruction; sys != nil && len(sys.Parts) > 0 {
		raw, _ := json.Marshal(sys.Parts)
		segments = append(segments, "s:"+shortHash(raw))
	}

	// Then one segment per conversation turn, in order.
	for _, content := range req.Contents {
		raw, _ := json.Marshal(content.Parts)
		tag := "u"
		if content.Role == "model" {
			tag = "m"
		}
		segments = append(segments, tag+":"+shortHash(raw))
	}

	return strings.Join(segments, "-")
}
// GenerateGeminiPrefixHash produces the prefix hash used to partition Gemini
// sessions.
//
// The identifiers are concatenated as
// "userID:apiKeyID:ip:userAgent:platform:model" (IDs in decimal), hashed with
// SHA-256, and the first 12 bytes of the digest are encoded with unpadded
// URL-safe Base64, which yields exactly 16 characters.
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
	// 2 * 20 bytes covers the longest decimal int64s, plus 5 separators.
	buf := make([]byte, 0, 45+len(ip)+len(userAgent)+len(platform)+len(model))
	buf = strconv.AppendInt(buf, userID, 10)
	buf = append(buf, ':')
	buf = strconv.AppendInt(buf, apiKeyID, 10)
	for _, field := range [...]string{ip, userAgent, platform, model} {
		buf = append(buf, ':')
		buf = append(buf, field...)
	}

	digest := sha256.Sum256(buf)
	// 12 bytes -> 16 Base64 characters with no padding.
	return base64.RawURLEncoding.EncodeToString(digest[:12])
}
// BuildGeminiSessionKey assembles the Redis key for a Gemini session.
// Format: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
	group := strconv.FormatInt(groupID, 10)
	scope := geminiSessionKeyPrefix + group
	return scope + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes lists every prefix of a digest chain, longest
// first, by repeatedly cutting at the last "-" separator. The full chain is
// always the first element. An empty chain yields nil.
//
// A "-" at index 0 is not treated as a separator, so the shortest prefix is
// never the empty string. Per the original note, the result is meant for a
// batched (MGET) longest-match lookup.
func GenerateDigestChainPrefixes(chain string) []string {
	if chain == "" {
		return nil
	}

	result := []string{chain}
	for cut := strings.LastIndex(chain, "-"); cut > 0; cut = strings.LastIndex(chain[:cut], "-") {
		result = append(result, chain[:cut])
	}
	return result
}
// ParseGeminiSessionValue splits a cached Gemini session value of the form
// {uuid}:{accountID}.
//
// The split happens at the LAST ":" so that a uuid containing ":" still
// parses. ok is false (with zero values for the other results) when the value
// is empty, has no ":", has an empty uuid or account part, or the account part
// is not a base-10 int64.
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
	sep := strings.LastIndex(value, ":")
	// sep <= 0: empty input, no separator, or empty uuid.
	// sep == len-1: nothing after the separator.
	if sep <= 0 || sep == len(value)-1 {
		return "", 0, false
	}

	id, parseErr := strconv.ParseInt(value[sep+1:], 10, 64)
	if parseErr != nil {
		return "", 0, false
	}
	return value[:sep], id, true
}
// FormatGeminiSessionValue renders a Gemini session cache value.
// Format: {uuid}:{accountID}, with accountID in base 10.
func FormatGeminiSessionValue(uuid string, accountID int64) string {
	// 1 separator + up to 20 bytes for a decimal int64.
	out := make([]byte, 0, len(uuid)+21)
	out = append(out, uuid...)
	out = append(out, ':')
	out = strconv.AppendInt(out, accountID, 10)
	return string(out)
}
// geminiDigestSessionKeyPrefix is the key prefix for Gemini digest-fallback sessions.
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix is the key prefix for Gemini Trie sessions.
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey assembles the Redis key for a Gemini Trie.
// Format: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
	group := strconv.FormatInt(groupID, 10)
	return geminiTrieKeyPrefix + group + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey builds the session key used by the Gemini
// digest fallback.
//
// The key is geminiDigestSessionKeyPrefix followed by the first 8 bytes of
// prefixHash, a ":", and the first 8 bytes of uuid; inputs shorter than 8
// bytes are used whole. Combining both parts is intended to give different
// sessions different keys. Per the original note, the key is used to keep
// sticky sessions in SelectAccountWithLoadAwareness.
func GenerateGeminiDigestSessionKey(prefixHash, uuid string) string {
	const keep = 8

	// head returns at most the first `keep` bytes of s.
	head := func(s string) string {
		if len(s) < keep {
			return s
		}
		return s[:keep]
	}

	return geminiDigestSessionKeyPrefix + head(prefixHash) + ":" + head(uuid)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment