"backend/internal/pkg/vscode:/vscode.git/clone" did not exist on "0fe09f1d40b128d8effa6f145bb8ab62d64d2c6f"
Unverified Commit 9e3c306a authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #513 from touwaeriol/pr/antigravity-full-v2

feat(antigravity): comprehensive enhancements — rate limiting, scheduling & smart retry
parents b1c30df8 3077fd27
...@@ -238,3 +238,41 @@ func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, pre ...@@ -238,3 +238,41 @@ func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, pre
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
} }
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}
package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
// Anthropic 会话 Fallback 相关常量
const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
anthropicTrieKeyPrefix = "anthropic:trie:"
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
func AnthropicSessionTTL() time.Duration {
return anthropicSessionTTLSeconds * time.Second
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var parts []string
// 1. system prompt
if parsed.System != nil {
systemData, _ := json.Marshal(parsed.System)
if len(systemData) > 0 && string(systemData) != "null" {
parts = append(parts, "s:"+shortHash(systemData))
}
}
// 2. messages
for _, msg := range parsed.Messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
prefix := rolePrefix(role)
content := msgMap["content"]
contentData, _ := json.Marshal(content)
parts = append(parts, prefix+":"+shortHash(contentData))
}
return strings.Join(parts, "-")
}
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
func rolePrefix(role string) string {
switch role {
case "assistant":
return "a"
default:
return "u"
}
}
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
// 格式: anthropic:trie:{groupID}:{prefixHash}
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
}
package service
import (
"strings"
"testing"
)
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
result := BuildAnthropicDigestChain(nil)
if result != "" {
t.Errorf("expected empty string for nil request, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{},
}
result := BuildAnthropicDigestChain(parsed)
if result != "" {
t.Errorf("expected empty string for empty messages, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "a:") {
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
parsed := &ParsedRequest{
System: "You are a helpful assistant",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "u:") {
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant"},
},
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system := "You are a helpful assistant"
// 第 1 轮: system + user
round1 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(round1)
// 第 2 轮: system + user + assistant + user
round2 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
},
}
chain2 := BuildAnthropicDigestChain(round2)
// 第 3 轮: system + user + assistant + user + assistant + user
round3 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well"},
map[string]any{"role": "user", "content": "great"},
},
}
chain3 := BuildAnthropicDigestChain(round3)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
t.Logf("Chain3: %s", chain3)
// chain1 是 chain2 的前缀
if !strings.HasPrefix(chain2, chain1) {
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
}
// chain2 是 chain3 的前缀
if !strings.HasPrefix(chain3, chain2) {
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
}
// chain1 也是 chain3 的前缀(传递性)
if !strings.HasPrefix(chain3, chain1) {
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
}
}
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
System: "System A",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "System B",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different system prompts should produce different chains")
}
// 但 user 部分的 hash 应该相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[1] != parts2[1] {
t.Error("Same user message should produce same hash regardless of system")
}
}
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
map[string]any{"role": "user", "content": "next"},
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
map[string]any{"role": "user", "content": "next"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different content should produce different chains")
}
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
// 第一个 user message hash 应该相同
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
// assistant reply hash 应该不同
if parts1[1] == parts2[1] {
t.Error("Assistant reply hash should differ")
}
}
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
System: "test system",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
chain1 := BuildAnthropicDigestChain(parsed)
chain2 := BuildAnthropicDigestChain(parsed)
if chain1 != chain2 {
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
}
}
func TestBuildAnthropicTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "anthropic:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "anthropic:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "anthropic:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "anthropic:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "anthropic:digest:12345678:abcdefgh",
},
{
name: "short values",
prefixHash: "abc",
uuid: "xyz",
want: "anthropic:digest:abc:xyz",
},
{
name: "empty values",
prefixHash: "",
uuid: "",
want: "anthropic:digest::",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestAnthropicSessionTTL(t *testing.T) {
ttl := AnthropicSessionTTL()
if ttl.Seconds() != 300 {
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
}
}
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
// 测试 content 为 content blocks 数组的情况
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe this image"},
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
},
},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
...@@ -35,7 +35,7 @@ const ( ...@@ -35,7 +35,7 @@ const (
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号 // - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
antigravityRateLimitThreshold = 7 * time.Second antigravityRateLimitThreshold = 7 * time.Second
antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间 antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间
antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数 antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// Google RPC 状态和类型常量 // Google RPC 状态和类型常量
...@@ -247,6 +247,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -247,6 +247,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
} }
// 清除粘性会话绑定,避免下次请求仍命中限流账号
if s.cache != nil && p.sessionHash != "" {
_ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
// 返回账号切换信号,让上层切换账号重试 // 返回账号切换信号,让上层切换账号重试
return &smartRetryResult{ return &smartRetryResult{
action: smartRetryActionBreakWithResp, action: smartRetryActionBreakWithResp,
...@@ -264,27 +269,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam ...@@ -264,27 +269,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// antigravityRetryLoop 执行带 URL fallback 的重试循环 // antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,根据剩余时间决定等待或切换 // 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" { if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
if remaining < antigravityRateLimitThreshold { log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
// 限流剩余时间较短,等待后继续 p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d", return nil, &AntigravityAccountSwitchError{
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) OriginalAccountID: p.account.ID,
select { RateLimitedModel: p.requestedModel,
case <-p.ctx.Done(): IsStickySession: p.isStickySession,
return nil, p.ctx.Err()
case <-time.After(remaining):
}
} else {
// 限流剩余时间较长,返回账号切换信号
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
} }
} }
} }
...@@ -964,6 +957,16 @@ func isModelNotFoundError(statusCode int, body []byte) bool { ...@@ -964,6 +957,16 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
} }
// Forward 转发 Claude 协议请求(Claude → Gemini 转换) // Forward 转发 Claude 协议请求(Claude → Gemini 转换)
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
sessionID := getSessionID(c) sessionID := getSessionID(c)
...@@ -1583,6 +1586,16 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque ...@@ -1583,6 +1586,16 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
} }
// ForwardGemini 转发 Gemini 协议请求 // ForwardGemini 转发 Gemini 协议请求
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
sessionID := getSessionID(c) sessionID := getSessionID(c)
......
...@@ -803,7 +803,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) { ...@@ -803,7 +803,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope") require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
} }
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) { func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
upstream := &recordingOKUpstream{} upstream := &recordingOKUpstream{}
account := &Account{ account := &Account{
ID: 1, ID: 1,
...@@ -815,19 +815,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi ...@@ -815,19 +815,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
Extra: map[string]any{ Extra: map[string]any{
modelRateLimitsKey: map[string]any{ modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{ "claude-sonnet-4-5": map[string]any{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339), "rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
}, },
}, },
}, },
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
svc := &AntigravityGatewayService{} svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx, ctx: context.Background(),
prefix: "[test]", prefix: "[test]",
account: account, account: account,
accessToken: "token", accessToken: "token",
...@@ -841,12 +837,16 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi ...@@ -841,12 +837,16 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
}, },
}) })
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, result) require.Nil(t, result)
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check") var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
} }
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) { func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
upstream := &recordingOKUpstream{} upstream := &recordingOKUpstream{}
account := &Account{ account := &Account{
ID: 2, ID: 2,
......
...@@ -232,6 +232,14 @@ func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, gro ...@@ -232,6 +232,14 @@ func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, gro
return nil return nil
} }
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct { type mockGroupRepoForGateway struct {
groups map[int64]*Group groups map[int64]*Group
getByIDCalls int getByIDCalls int
......
...@@ -313,6 +313,14 @@ type GatewayCache interface { ...@@ -313,6 +313,14 @@ type GatewayCache interface {
// SaveGeminiSession 保存 Gemini 会话 // SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding // Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话(Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
} }
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil // derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...@@ -323,21 +331,15 @@ func derefGroupID(groupID *int64) int64 { ...@@ -323,21 +331,15 @@ func derefGroupID(groupID *int64) int64 {
return *groupID return *groupID
} }
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 // 或请求的模型处于限流状态时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。 // 这确保后续请求不会继续使用不可用的账号。
// //
// shouldClearStickySession checks if an account is in an unschedulable state // shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared. // and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false, // Returns true when account status is error/disabled, schedulable is false,
// within temporary unschedulable period, or model rate limit remaining time // within temporary unschedulable period, or the requested model is rate-limited.
// exceeds stickySessionRateLimitThreshold.
// This ensures subsequent requests won't continue using unavailable accounts. // This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account, requestedModel string) bool { func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil { if account == nil {
...@@ -349,8 +351,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { ...@@ -349,8 +351,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true return true
} }
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 // 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true return true
} }
return false return false
...@@ -488,23 +490,25 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -488,23 +490,25 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent) return s.hashContent(cacheableContent)
} }
// 3. Fallback: 使用 system 内容 // 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
var combined strings.Builder
if parsed.System != nil { if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System) systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" { if systemText != "" {
return s.hashContent(systemText) _, _ = combined.WriteString(systemText)
} }
} }
for _, msg := range parsed.Messages {
// 4. 最后 fallback: 使用第一条消息 if m, ok := msg.(map[string]any); ok {
if len(parsed.Messages) > 0 { msgText := s.extractTextFromContent(m["content"])
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" { if msgText != "" {
return s.hashContent(msgText) _, _ = combined.WriteString(msgText)
} }
} }
} }
if combined.Len() > 0 {
return s.hashContent(combined.String())
}
return "" return ""
} }
...@@ -547,6 +551,22 @@ func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, p ...@@ -547,6 +551,22 @@ func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, p
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
} }
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
}
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
return nil
}
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil { if parsed == nil {
return "" return ""
...@@ -1110,7 +1130,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1110,7 +1130,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result.ReleaseFunc() // 释放槽位 result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择 // 继续到负载感知选择
} else { } else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
} }
...@@ -1264,7 +1283,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1264,7 +1283,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
Acquired: true, Acquired: true,
...@@ -2169,9 +2187,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2169,9 +2187,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
} }
...@@ -2272,9 +2287,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2272,9 +2287,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil return account, nil
} }
} }
...@@ -2383,9 +2395,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2383,9 +2395,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
} }
...@@ -2488,9 +2497,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2488,9 +2497,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil return account, nil
} }
} }
......
...@@ -281,6 +281,14 @@ func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, group ...@@ -281,6 +281,14 @@ func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, group
return nil return nil
} }
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background() ctx := context.Background()
......
...@@ -220,6 +220,14 @@ func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, ...@@ -220,6 +220,14 @@ func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64,
return nil return nil
} }
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now() now := time.Now()
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
......
...@@ -23,8 +23,7 @@ import ( ...@@ -23,8 +23,7 @@ import (
// - 临时不可调度且未过期:清理 // - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理 // - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理 // - 正常可调度状态:不清理
// - 模型限流超过阈值:清理 // - 模型限流(任意时长):清理
// - 模型限流未超过阈值:不清理
// //
// TestShouldClearStickySession tests the sticky session clearing logic. // TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including: // Verifies correct behavior for various account states including:
...@@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) { ...@@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) {
future := now.Add(1 * time.Hour) future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour) past := now.Add(-1 * time.Hour)
// 短限流时间(低于阈值,不应清除粘性会话) // 短限流时间(有限流即清除粘性会话)
shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339) shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339)
// 长限流时间(超过阈值,应清除粘性会话) // 长限流时间(有限流即清除粘性会话)
longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339) longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339)
tests := []struct { tests := []struct {
...@@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) { ...@@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) {
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true}, {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false}, {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false}, {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false},
// 模型限流测试 // 模型限流测试:有限流即清除
{ {
name: "model rate limited short duration", name: "model rate limited short duration",
account: &Account{ account: &Account{
...@@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) { ...@@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) {
}, },
}, },
requestedModel: "claude-sonnet-4", requestedModel: "claude-sonnet-4",
want: false, // 低于阈值,不清除 want: true, // 有限流即清除
}, },
{ {
name: "model rate limited long duration", name: "model rate limited long duration",
...@@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) { ...@@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) {
}, },
}, },
requestedModel: "claude-sonnet-4", requestedModel: "claude-sonnet-4",
want: true, // 超过阈值,清除 want: true, // 有限流即清除
}, },
{ {
name: "model rate limited different model", name: "model rate limited different model",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment