Commit 5e98445b authored by erio's avatar erio
Browse files

feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops

Key changes:
- Upgrade model mapping: Opus 4.5 → Opus 4.6-thinking with precise matching
- Unified rate limiting: scope-level → model-level with Redis snapshot sync
- Load-balanced scheduling by call count with smart retry mechanism
- Force cache billing support
- Model identity injection in prompts with leak prevention
- Thinking mode auto-handling (max_tokens/budget_tokens fix)
- Frontend: whitelist mode toggle, model mapping validation, status indicators
- Gemini session fallback with Redis Trie O(L) matching
- Ops: enhanced concurrency monitoring, account availability, retry logic
- Migration scripts: 049-051 for model mapping unification
parent e617b45b
......@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
......@@ -113,7 +114,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5",
"model": "claude-opus-4-6",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
......@@ -149,7 +150,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
},
}
result, err := svc.Forward(context.Background(), c, account, body)
result, err := svc.Forward(context.Background(), c, account, body, false)
require.Nil(t, result)
var promptErr *PromptTooLongError
......@@ -166,27 +167,227 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
require.Equal(t, "prompt_too_long", events[0].Kind)
}
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "4")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
// TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover
// 验证:当账号存在模型限流且剩余时间 >= antigravityRateLimitThreshold 时,
// Forward 方法应返回 UpstreamFailoverError,触发 Handler 切换账号
func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 1,
Name: "acc-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
require.Equal(t, 4, got)
result, err := svc.Forward(context.Background(), c, account, body, false)
require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
require.Equal(t, 7, got)
// 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求,ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "5")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
// TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover
// 验证:ForwardGemini 方法同样能正确将 AntigravityAccountSwitchError 转换为 UpstreamFailoverError
func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 2,
Name: "acc-gemini-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
// 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求,ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling
// 验证:粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]string{{"role": "user", "content": "hello"}},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 3,
Name: "acc-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.Forward(context.Background(), c, account, body, true)
require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 4,
Name: "acc-gemini-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
require.Equal(t, 5, got)
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
......@@ -14,32 +14,28 @@ func TestIsAntigravityModelSupported(t *testing.T) {
model string
expected bool
}{
// 直接支持的模型
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
// 可映射的模型
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// 在默认映射中的模型(支持)
{"默认映射 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"默认映射 - claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
{"默认映射 - claude-opus-4-6", "claude-opus-4-6", true},
{"默认映射 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"默认映射 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"默认映射 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"默认映射 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"默认映射 - gemini-3-pro-high", "gemini-3-pro-high", true},
{"默认映射 - claude-haiku-4-5", "claude-haiku-4-5", true},
// Claude 前缀兜底
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
{"Claude前缀 - claude-future-version", "claude-future-version", true},
// 不在默认映射中的模型(不支持)
{"未配置 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", false},
{"未配置 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", false},
{"未配置 - claude-3-haiku-20240307", "claude-3-haiku-20240307", false},
{"未配置 - gemini-unknown-model", "gemini-unknown-model", false},
{"未配置 - gemini-future-version", "gemini-future-version", false},
{"未配置 - claude-unknown-model", "claude-unknown-model", false},
{"未配置 - claude-3-opus-20240229", "claude-3-opus-20240229", false},
{"未配置 - claude-future-version", "claude-future-version", false},
// 不支持的模型
// 非 Claude/Gemini 模型(不支持)
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false},
......@@ -64,7 +60,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
// 1. 账户级映射优先
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
......@@ -72,120 +68,124 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
name: "账户映射 - 可覆盖默认映射的模型",
requestedModel: "claude-sonnet-4-5",
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
expected: "my-custom-sonnet",
},
{
name: "账户映射 - 可覆盖未知模型",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
// 2. 默认映射(DefaultAntigravityModelMapping)
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
// 3. 默认映射中的透传(映射到自己)
{
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307",
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-opus-4-6-thinking",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
name: "默认映射透传 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 2.5 → 3 映射
{
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
name: "默认映射透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "gemini-2.5-flash",
},
{
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
name: "默认映射透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-3-pro-high",
expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
name: "默认映射透传 - gemini-3-flash",
requestedModel: "gemini-3-flash",
accountMapping: nil,
expected: "gemini-future-model",
expected: "gemini-3-flash",
},
// 4. 直接支持的模型
// 4. 未在默认映射中的模型返回空字符串(不支持)
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
name: "未知模型 - claude-unknown 返回空",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
expected: "",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
name: "未知模型 - claude-3-opus-20240229 返回空",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "",
},
// 5. 默认值 fallback(未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
name: "未知模型 - claude-opus-4 返回空",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
name: "未知模型 - gemini-future-model 返回空",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "",
},
}
......@@ -219,12 +219,10 @@ func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
// 空字符串和非 claude/gemini 前缀返回空字符串
{"空字符串", "", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
{"非claude/gemini前缀 - llama", "llama-3", ""},
}
for _, tt := range tests {
......@@ -248,10 +246,10 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 可映射(有明确前缀映射)
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传
// 前缀透传(claude 和 gemini 前缀)
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
......
package service
import (
"context"
"slices"
"strings"
"time"
......@@ -57,15 +58,20 @@ func normalizeAntigravityModelName(model string) string {
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
}
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.isModelRateLimited(requestedModel) {
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
......@@ -132,3 +138,43 @@ func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
}
This diff is collapsed.
//go:build unit
package service
import (
"testing"
)
func TestApplyThinkingModelSuffix(t *testing.T) {
tests := []struct {
name string
mappedModel string
thinkingEnabled bool
expected string
}{
// Thinking 未开启:保持原样
{
name: "thinking disabled - claude-sonnet-4-5 unchanged",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: "claude-sonnet-4-5",
},
{
name: "thinking disabled - other model unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: false,
expected: "claude-opus-4-6-thinking",
},
// Thinking 开启 + claude-sonnet-4-5:自动添加后缀
{
name: "thinking enabled - claude-sonnet-4-5 becomes thinking version",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
// Thinking 开启 + 其他模型:保持原样
{
name: "thinking enabled - claude-sonnet-4-5-thinking unchanged",
mappedModel: "claude-sonnet-4-5-thinking",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
{
name: "thinking enabled - claude-opus-4-6-thinking unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: true,
expected: "claude-opus-4-6-thinking",
},
{
name: "thinking enabled - gemini model unchanged",
mappedModel: "gemini-3-flash",
thinkingEnabled: true,
expected: "gemini-3-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := applyThinkingModelSuffix(tt.mappedModel, tt.thinkingEnabled)
if result != tt.expected {
t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
tt.mappedModel, tt.thinkingEnabled, result, tt.expected)
}
})
}
}
......@@ -35,6 +35,7 @@ type ConcurrencyCache interface {
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
......@@ -77,6 +78,11 @@ type AccountWithConcurrency struct {
MaxConcurrency int
}
type UserWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
......@@ -84,6 +90,13 @@ type AccountLoadInfo struct {
LoadRate int // 0-100+ (percent)
}
type UserLoadInfo struct {
UserID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
......@@ -253,6 +266,14 @@ func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// GetUsersLoadBatch returns load info for multiple users.
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
if s.cache == nil {
return map[int64]*UserLoadInfo{}, nil
}
return s.cache.GetUsersLoadBatch(ctx, users)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
......
package service
import "github.com/gin-gonic/gin"
const errorPassthroughServiceContextKey = "error_passthrough_service"
// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。
func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) {
if c == nil || svc == nil {
return
}
c.Set(errorPassthroughServiceContextKey, svc)
}
func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService {
if c == nil {
return nil
}
v, ok := c.Get(errorPassthroughServiceContextKey)
if !ok {
return nil
}
svc, ok := v.(*ErrorPassthroughService)
if !ok {
return nil
}
return svc
}
// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。
func applyErrorPassthroughRule(
c *gin.Context,
platform string,
upstreamStatus int,
responseBody []byte,
defaultStatus int,
defaultErrType string,
defaultErrMsg string,
) (status int, errType string, errMsg string, matched bool) {
status = defaultStatus
errType = defaultErrType
errMsg = defaultErrMsg
svc := getBoundErrorPassthroughService(c)
if svc == nil {
return status, errType, errMsg, false
}
rule := svc.MatchRule(platform, upstreamStatus, responseBody)
if rule == nil {
return status, errType, errMsg, false
}
status = upstreamStatus
if !rule.PassthroughCode && rule.ResponseCode != nil {
status = *rule.ResponseCode
}
errMsg = ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
errMsg = *rule.CustomMessage
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error"
return status, errType, errMsg, true
}
This diff is collapsed.
......@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ParsedRequest 保存网关请求的预解析结果
......@@ -26,6 +28,7 @@ type ParsedRequest struct {
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
......@@ -69,6 +72,13 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.Messages = messages
}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
return parsed, nil
}
......@@ -466,7 +476,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
if signature != "" && signature != antigravity.DummyThoughtSignature {
newContent = append(newContent, block)
continue
}
......
......@@ -17,6 +17,15 @@ func TestParseGatewayRequest(t *testing.T) {
require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1)
require.False(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
......
This diff is collapsed.
......@@ -265,6 +265,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
......@@ -880,7 +896,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
model: "claude-sonnet-4-5",
expected: true,
},
{
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment