Commit 5e98445b authored by erio's avatar erio
Browse files

feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops

Key changes:
- Upgrade model mapping: Opus 4.5 → Opus 4.6-thinking with precise matching
- Unified rate limiting: scope-level → model-level with Redis snapshot sync
- Load-balanced scheduling by call count with smart retry mechanism
- Force cache billing support
- Model identity injection in prompts with leak prevention
- Thinking mode auto-handling (max_tokens/budget_tokens fix)
- Frontend: whitelist mode toggle, model mapping validation, status indicators
- Gemini session fallback with Redis Trie O(L) matching
- Ops: enhanced concurrency monitoring, account availability, retry logic
- Migration scripts: 049-051 for model mapping unification
parent e617b45b
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache is an in-memory test double that simulates the Redis cache.
type mockGeminiSessionCache struct {
sessions map[string]string // session key -> formatted session value
}
// newMockGeminiSessionCache returns a mock cache with an empty, initialized session map.
func newMockGeminiSessionCache() *mockGeminiSessionCache {
	cache := new(mockGeminiSessionCache)
	cache.sessions = map[string]string{}
	return cache
}
// Save stores the formatted (uuid, accountID) value under the key built from
// groupID, prefixHash and digestChain.
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
	m.sessions[BuildGeminiSessionKey(groupID, prefixHash, digestChain)] = FormatGeminiSessionValue(uuid, accountID)
}
// Find walks the prefixes generated from digestChain in the order they are
// returned and parses the first stored value whose key matches. It reports
// found=false when no prefix has an entry.
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
	for _, candidate := range GenerateDigestChainPrefixes(digestChain) {
		stored, hit := m.sessions[BuildGeminiSessionKey(groupID, prefixHash, candidate)]
		if !hit {
			continue
		}
		return ParseGeminiSessionValue(stored)
	}
	return "", 0, false
}
// TestGeminiSessionContinuousConversation verifies digest-chain matching across
// the successive rounds of one continuous conversation.
func TestGeminiSessionContinuousConversation(t *testing.T) {
	const (
		prefixHash  = "test_prefix_hash"
		sessionUUID = "session-uuid-12345"
	)
	var (
		groupID   int64 = 1
		accountID int64 = 100
	)
	cache := newMockGeminiSessionCache()

	// turn builds one conversation entry holding a single text part.
	turn := func(role, text string) antigravity.GeminiContent {
		return antigravity.GeminiContent{Role: role, Parts: []antigravity.GeminiPart{{Text: text}}}
	}
	// request wraps the given turns together with the shared system instruction.
	request := func(turns ...antigravity.GeminiContent) *antigravity.GeminiRequest {
		return &antigravity.GeminiRequest{
			SystemInstruction: &antigravity.GeminiContent{
				Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
			},
			Contents: turns,
		}
	}

	// Round 1: nothing is cached yet, so the lookup must miss.
	chain1 := BuildGeminiDigestChain(request(
		turn("user", "Hello, what's your name?"),
	))
	t.Logf("Round 1 chain: %s", chain1)
	if _, _, hit := cache.Find(groupID, prefixHash, chain1); hit {
		t.Error("Round 1: should not find existing session")
	}
	// Persist the round-1 session.
	cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)

	// Round 2: the user continues the conversation.
	chain2 := BuildGeminiDigestChain(request(
		turn("user", "Hello, what's your name?"),
		turn("model", "I'm Claude, nice to meet you!"),
		turn("user", "What can you do?"),
	))
	t.Logf("Round 2 chain: %s", chain2)
	// The session should be found through prefix matching.
	gotUUID, gotAccount, hit := cache.Find(groupID, prefixHash, chain2)
	if !hit {
		t.Error("Round 2: should find session via prefix matching")
	}
	if gotUUID != sessionUUID {
		t.Errorf("Round 2: expected UUID %s, got %s", sessionUUID, gotUUID)
	}
	if gotAccount != accountID {
		t.Errorf("Round 2: expected accountID %d, got %d", accountID, gotAccount)
	}
	// Persist the round-2 session.
	cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)

	// Round 3: should match through the prefix saved in round 2.
	chain3 := BuildGeminiDigestChain(request(
		turn("user", "Hello, what's your name?"),
		turn("model", "I'm Claude, nice to meet you!"),
		turn("user", "What can you do?"),
		turn("model", "I can help with coding, writing, and more!"),
		turn("user", "Great, help me write some Go code"),
	))
	t.Logf("Round 3 chain: %s", chain3)
	gotUUID, gotAccount, hit = cache.Find(groupID, prefixHash, chain3)
	if !hit {
		t.Error("Round 3: should find session via prefix matching")
	}
	if gotUUID != sessionUUID {
		t.Errorf("Round 3: expected UUID %s, got %s", sessionUUID, gotUUID)
	}
	if gotAccount != accountID {
		t.Errorf("Round 3: expected accountID %d, got %d", accountID, gotAccount)
	}
	t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations verifies that unrelated conversations
// are not matched to each other.
func TestGeminiSessionDifferentConversations(t *testing.T) {
	const prefixHash = "test_prefix_hash"
	var groupID int64 = 1
	cache := newMockGeminiSessionCache()

	// chainFor builds the digest chain of a request with a single user message.
	chainFor := func(text string) string {
		return BuildGeminiDigestChain(&antigravity.GeminiRequest{
			Contents: []antigravity.GeminiContent{
				{Role: "user", Parts: []antigravity.GeminiPart{{Text: text}}},
			},
		})
	}

	// First conversation is cached.
	cache.Save(groupID, prefixHash, chainFor("Tell me about Go programming"), "session-1", 100)

	// A completely different second conversation must not match it.
	otherChain := chainFor("What's the weather today?")
	if _, _, hit := cache.Find(groupID, prefixHash, otherChain); hit {
		t.Error("Different conversations should not match")
	}
	t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder verifies prefix-matching priority: when
// several prefixes of one chain are cached, the longest match wins.
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
	cache := newMockGeminiSessionCache()
	groupID := int64(1)
	prefixHash := "test_prefix_hash"

	// Build a three-turn conversation with a system prompt.
	req := &antigravity.GeminiRequest{
		SystemInstruction: &antigravity.GeminiContent{
			Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
		},
		Contents: []antigravity.GeminiContent{
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
			{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
		},
	}
	fullChain := BuildGeminiDigestChain(req)
	prefixes := GenerateDigestChainPrefixes(fullChain)
	t.Logf("Full chain: %s", fullChain)
	t.Logf("Prefixes (longest first): %v", prefixes)

	// Verify the prefix count. This must be fatal: prefixes[3] is indexed below
	// and would panic with an index-out-of-range on a shorter slice.
	if len(prefixes) != 4 {
		t.Fatalf("Expected 4 prefixes, got %d", len(prefixes))
	}

	// Save sessions for different rounds under different accounts.
	// Round 1 (shortest prefix) -> account 1
	cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
	// Round 2 -> account 2
	cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
	// Round 3 (longest prefix, the full chain) -> account 3
	cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)

	// The lookup should return the longest match (account 3).
	_, accID, found := cache.Find(groupID, prefixHash, fullChain)
	if !found {
		t.Error("Should find session")
	}
	if accID != 3 {
		t.Errorf("Should match longest prefix (account 3), got account %d", accID)
	}
	t.Log("✓ Longest prefix matching works correctly!")
}
// Reference the context package so that its import is used (avoids an unused-import warning).
var _ = context.Background
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// TestShortHash checks that shortHash output is bounded in length and deterministic.
func TestShortHash(t *testing.T) {
	cases := []struct {
		name  string
		input []byte
	}{
		{"empty", []byte{}},
		{"simple", []byte("hello world")},
		{"json", []byte(`{"role":"user","parts":[{"text":"hello"}]}`)},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			first := shortHash(tc.input)
			// A uint64 encoded in base 36 is at most 13 characters long.
			if n := len(first); n > 13 {
				t.Errorf("shortHash result too long: %d characters", n)
			}
			// The same input must produce the same output.
			if second := shortHash(tc.input); first != second {
				t.Errorf("shortHash not deterministic: %s vs %s", first, second)
			}
		})
	}
}
// TestBuildGeminiDigestChain checks the number and the format of the segments
// that BuildGeminiDigestChain produces for various request shapes.
func TestBuildGeminiDigestChain(t *testing.T) {
	tests := []struct {
		name     string
		req      *antigravity.GeminiRequest
		wantLen  int  // expected number of segments
		hasEmpty bool // whether the result should be the empty string
	}{
		{
			name:     "nil request",
			req:      nil,
			hasEmpty: true,
		},
		{
			name: "empty contents",
			req: &antigravity.GeminiRequest{
				Contents: []antigravity.GeminiContent{},
			},
			hasEmpty: true,
		},
		{
			name: "single user message",
			req: &antigravity.GeminiRequest{
				Contents: []antigravity.GeminiContent{
					{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
				},
			},
			wantLen: 1, // u:<hash>
		},
		{
			name: "user and model messages",
			req: &antigravity.GeminiRequest{
				Contents: []antigravity.GeminiContent{
					{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
					{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi there"}}},
				},
			},
			wantLen: 2, // u:<hash>-m:<hash>
		},
		{
			name: "with system instruction",
			req: &antigravity.GeminiRequest{
				SystemInstruction: &antigravity.GeminiContent{
					Role:  "user",
					Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
				},
				Contents: []antigravity.GeminiContent{
					{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
				},
			},
			wantLen: 2, // s:<hash>-u:<hash>
		},
		{
			name: "conversation with system",
			req: &antigravity.GeminiRequest{
				SystemInstruction: &antigravity.GeminiContent{
					Role:  "user",
					Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
				},
				Contents: []antigravity.GeminiContent{
					{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
					{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi"}}},
					{Role: "user", Parts: []antigravity.GeminiPart{{Text: "how are you?"}}},
				},
			},
			wantLen: 4, // s:<hash>-u:<hash>-m:<hash>-u:<hash>
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := BuildGeminiDigestChain(tt.req)
			if tt.hasEmpty {
				if result != "" {
					t.Errorf("expected empty string, got: %s", result)
				}
				return
			}
			// Check the number of segments.
			parts := splitChain(result)
			if len(parts) != tt.wantLen {
				t.Errorf("expected %d parts, got %d: %s", tt.wantLen, len(parts), result)
			}
			// Validate the format of every segment.
			for _, part := range parts {
				if len(part) < 3 || part[1] != ':' {
					t.Errorf("invalid part format: %s", part)
					// Skip the prefix check: part may be empty, and indexing
					// part[0] below would panic.
					continue
				}
				prefix := part[0]
				if prefix != 's' && prefix != 'u' && prefix != 'm' {
					t.Errorf("invalid prefix: %c", prefix)
				}
			}
		})
	}
}
// TestGenerateGeminiPrefixHash checks determinism, input sensitivity and the
// output length of GenerateGeminiPrefixHash.
func TestGenerateGeminiPrefixHash(t *testing.T) {
	base := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
	repeat := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
	changed := GenerateGeminiPrefixHash(2, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")

	// Identical inputs must hash identically.
	if base != repeat {
		t.Errorf("GenerateGeminiPrefixHash not deterministic: %s vs %s", base, repeat)
	}
	// Different inputs must hash differently.
	if base == changed {
		t.Errorf("GenerateGeminiPrefixHash collision for different inputs")
	}
	// 12 bytes in URL-safe Base64 are exactly 16 characters.
	if got := len(base); got != 16 {
		t.Errorf("expected 16 characters, got %d: %s", got, base)
	}
}
// TestGenerateDigestChainPrefixes checks the count and the order of the
// prefixes generated from a digest chain.
func TestGenerateDigestChainPrefixes(t *testing.T) {
	cases := []struct {
		name    string
		chain   string
		want    []string
		wantLen int
	}{
		{name: "empty", chain: "", wantLen: 0},
		{name: "single part", chain: "u:abc123", want: []string{"u:abc123"}, wantLen: 1},
		{name: "two parts", chain: "s:xyz-u:abc", want: []string{"s:xyz-u:abc", "s:xyz"}, wantLen: 2},
		{
			name:    "four parts",
			chain:   "s:a-u:b-m:c-u:d",
			want:    []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
			wantLen: 4,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := GenerateDigestChainPrefixes(tc.chain)
			if len(got) != tc.wantLen {
				t.Errorf("expected %d prefixes, got %d: %v", tc.wantLen, len(got), got)
			}
			if tc.want == nil {
				return
			}
			// Compare element by element, reporting every missing or wrong entry.
			for idx, expected := range tc.want {
				if idx >= len(got) {
					t.Errorf("missing prefix at index %d", idx)
					continue
				}
				if got[idx] != expected {
					t.Errorf("prefix[%d]: expected %s, got %s", idx, expected, got[idx])
				}
			}
		})
	}
}
// TestParseGeminiSessionValue checks parsing of "uuid:accountID" session values,
// including malformed input and UUIDs that themselves contain colons.
func TestParseGeminiSessionValue(t *testing.T) {
	cases := []struct {
		name      string
		value     string
		wantUUID  string
		wantAccID int64
		wantOK    bool
	}{
		{name: "empty", value: "", wantOK: false},
		{name: "no colon", value: "abc123", wantOK: false},
		{name: "valid", value: "uuid-1234:100", wantUUID: "uuid-1234", wantAccID: 100, wantOK: true},
		{name: "uuid with colon", value: "a:b:c:123", wantUUID: "a:b:c", wantAccID: 123, wantOK: true},
		{name: "invalid account id", value: "uuid:abc", wantOK: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotUUID, gotAccID, gotOK := ParseGeminiSessionValue(tc.value)
			if gotOK != tc.wantOK {
				t.Errorf("ok: expected %v, got %v", tc.wantOK, gotOK)
			}
			// The parsed fields are only compared for cases expected to succeed.
			if !tc.wantOK {
				return
			}
			if gotUUID != tc.wantUUID {
				t.Errorf("uuid: expected %s, got %s", tc.wantUUID, gotUUID)
			}
			if gotAccID != tc.wantAccID {
				t.Errorf("accountID: expected %d, got %d", tc.wantAccID, gotAccID)
			}
		})
	}
}
// TestFormatGeminiSessionValue checks the formatted value and that it survives
// a round trip through ParseGeminiSessionValue.
func TestFormatGeminiSessionValue(t *testing.T) {
	formatted := FormatGeminiSessionValue("test-uuid", 123)
	if want := "test-uuid:123"; formatted != want {
		t.Errorf("expected %s, got %s", want, formatted)
	}

	// Round-trip consistency.
	parsedUUID, parsedAccID, parsed := ParseGeminiSessionValue(formatted)
	if !parsed {
		t.Error("ParseGeminiSessionValue failed on formatted value")
	}
	if !(parsedUUID == "test-uuid" && parsedAccID == 123) {
		t.Errorf("round-trip failed: uuid=%s, accID=%d", parsedUUID, parsedAccID)
	}
}
// splitChain is a test helper that splits a digest chain on "-".
// An empty chain yields nil, and an empty segment after a trailing "-" is dropped.
func splitChain(chain string) []string {
	if len(chain) == 0 {
		return nil
	}
	var segments []string
	segStart := 0
	for pos, ch := range []byte(chain) {
		if ch != '-' {
			continue
		}
		segments = append(segments, chain[segStart:pos])
		segStart = pos + 1
	}
	// Append whatever follows the last separator, unless nothing is left.
	if segStart < len(chain) {
		segments = append(segments, chain[segStart:])
	}
	return segments
}
func TestDigestChainDifferentSysInstruction(t *testing.T) {
req1 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_ORIGINAL"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
req2 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_MODIFIED"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
if chain1 == chain2 {
t.Error("Different systemInstruction should produce different chains")
}
}
// TestDigestChainTamperedMiddleContent checks that altering a message in the
// middle of a conversation changes the chain, while the segment for the
// unchanged first message stays the same.
func TestDigestChainTamperedMiddleContent(t *testing.T) {
	req1 := &antigravity.GeminiRequest{
		Contents: []antigravity.GeminiContent{
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
			{Role: "model", Parts: []antigravity.GeminiPart{{Text: "ORIGINAL_REPLY"}}},
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
		},
	}
	req2 := &antigravity.GeminiRequest{
		Contents: []antigravity.GeminiContent{
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
			{Role: "model", Parts: []antigravity.GeminiPart{{Text: "TAMPERED_REPLY"}}},
			{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
		},
	}
	chain1 := BuildGeminiDigestChain(req1)
	chain2 := BuildGeminiDigestChain(req2)
	t.Logf("Chain1: %s", chain1)
	t.Logf("Chain2: %s", chain2)
	if chain1 == chain2 {
		t.Error("Tampered middle content should produce different chains")
	}

	parts1 := splitChain(chain1)
	parts2 := splitChain(chain2)
	// Guard the indexing below: a chain with fewer than two segments would
	// otherwise panic with an index-out-of-range instead of failing cleanly.
	if len(parts1) < 2 || len(parts2) < 2 {
		t.Fatalf("expected at least 2 segments per chain, got %d and %d", len(parts1), len(parts2))
	}
	// The hash of the first user message must be identical in both chains.
	if parts1[0] != parts2[0] {
		t.Error("First user message hash should be the same")
	}
	if parts1[1] == parts2[1] {
		t.Error("Model reply hash should be different")
	}
}
// TestGenerateGeminiDigestSessionKey checks the key format for long, short and
// empty inputs, plus determinism and sensitivity to the UUID.
func TestGenerateGeminiDigestSessionKey(t *testing.T) {
	cases := []struct {
		name       string
		prefixHash string
		uuid       string
		want       string
	}{
		{
			name:       "normal 16 char hash with uuid",
			prefixHash: "abcdefgh12345678",
			uuid:       "550e8400-e29b-41d4-a716-446655440000",
			want:       "gemini:digest:abcdefgh:550e8400",
		},
		{
			name:       "exactly 8 chars prefix and uuid",
			prefixHash: "12345678",
			uuid:       "abcdefgh",
			want:       "gemini:digest:12345678:abcdefgh",
		},
		{
			name:       "short hash and short uuid (less than 8)",
			prefixHash: "abc",
			uuid:       "xyz",
			want:       "gemini:digest:abc:xyz",
		},
		{
			name:       "empty hash and uuid",
			prefixHash: "",
			uuid:       "",
			want:       "gemini:digest::",
		},
		{
			name:       "normal prefix with short uuid",
			prefixHash: "abcdefgh12345678",
			uuid:       "short",
			want:       "gemini:digest:abcdefgh:short",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := GenerateGeminiDigestSessionKey(tc.prefixHash, tc.uuid); got != tc.want {
				t.Errorf("GenerateGeminiDigestSessionKey(%q, %q) = %q, want %q", tc.prefixHash, tc.uuid, got, tc.want)
			}
		})
	}

	// Determinism: identical inputs yield identical output.
	t.Run("deterministic", func(t *testing.T) {
		const (
			hash = "testprefix123456"
			uuid = "test-uuid-12345"
		)
		first := GenerateGeminiDigestSessionKey(hash, uuid)
		second := GenerateGeminiDigestSessionKey(hash, uuid)
		if first != second {
			t.Errorf("GenerateGeminiDigestSessionKey not deterministic: %s vs %s", first, second)
		}
	})

	// Different UUIDs yield different session keys (core load-balancing logic).
	t.Run("different uuid different key", func(t *testing.T) {
		const hash = "sameprefix123456"
		keyA := GenerateGeminiDigestSessionKey(hash, "uuid0001-session-a")
		keyB := GenerateGeminiDigestSessionKey(hash, "uuid0002-session-b")
		if keyA == keyB {
			t.Errorf("Different UUIDs should produce different session keys: %s vs %s", keyA, keyB)
		}
	})
}
// TestBuildGeminiTrieKey checks the trie key format for normal, zero-group and
// empty-prefix inputs.
func TestBuildGeminiTrieKey(t *testing.T) {
	cases := []struct {
		name       string
		groupID    int64
		prefixHash string
		want       string
	}{
		{name: "normal", groupID: 123, prefixHash: "abcdef12", want: "gemini:trie:123:abcdef12"},
		{name: "zero group", groupID: 0, prefixHash: "xyz", want: "gemini:trie:0:xyz"},
		{name: "empty prefix", groupID: 1, prefixHash: "", want: "gemini:trie:1:"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			key := BuildGeminiTrieKey(tc.groupID, tc.prefixHash)
			if key == tc.want {
				return
			}
			t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tc.groupID, tc.prefixHash, key, tc.want)
		})
	}
}
package service
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const (
	// modelRateLimitsKey is the map key under which rate-limit entries are stored.
	modelRateLimitsKey = "model_rate_limits"
	// modelRateLimitScopeClaudeSonnet is the rate-limit scope name for Sonnet models.
	modelRateLimitScopeClaudeSonnet = "claude_sonnet"
)
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
model := strings.ToLower(strings.TrimSpace(requestedModel))
if model == "" {
return "", false
// isRateLimitActiveForKey reports whether the rate limit recorded for key is
// still in effect, i.e. a reset time exists and lies in the future.
func (a *Account) isRateLimitActiveForKey(key string) bool {
	deadline := a.modelRateLimitResetAt(key)
	if deadline == nil {
		return false
	}
	return time.Now().Before(*deadline)
}
// getRateLimitRemainingForKey returns the remaining rate-limit duration for the given key; 0 means not rate limited or already expired
func (a *Account) getRateLimitRemainingForKey(key string) time.Duration {
resetAt := a.modelRateLimitResetAt(key)
if resetAt == nil {
return 0
}
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "sonnet") {
return modelRateLimitScopeClaudeSonnet, true
remaining := time.Until(*resetAt)
if remaining > 0 {
return remaining
}
return "", false
return 0
}
func (a *Account) isModelRateLimited(requestedModel string) bool {
scope, ok := resolveModelRateLimitScope(requestedModel)
if !ok {
func (a *Account) isModelRateLimitedWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
resetAt := a.modelRateLimitResetAt(scope)
if resetAt == nil {
modelKey := a.GetMappedModel(requestedModel)
if a.Platform == PlatformAntigravity {
modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
}
modelKey = strings.TrimSpace(modelKey)
if modelKey == "" {
return false
}
return time.Now().Before(*resetAt)
return a.isRateLimitActiveForKey(modelKey)
}
// GetModelRateLimitRemainingTime returns the remaining rate-limit duration for the requested model.
// A result of 0 means the model is not rate limited or the limit has already expired.
// It delegates to GetModelRateLimitRemainingTimeWithContext using context.Background().
func (a *Account) GetModelRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetModelRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetModelRateLimitRemainingTimeWithContext returns the remaining rate-limit
// duration for the model key that requestedModel resolves to on this account.
// The key is the account's mapped model; on the Antigravity platform it is
// replaced by the final Antigravity model key derived from ctx. A nil receiver
// or an empty resolved key yields 0.
func (a *Account) GetModelRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
	if a == nil {
		return 0
	}
	key := a.GetMappedModel(requestedModel)
	switch a.Platform {
	case PlatformAntigravity:
		key = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
	}
	if key = strings.TrimSpace(key); key == "" {
		return 0
	}
	return a.getRateLimitRemainingForKey(key)
}
// resolveFinalAntigravityModelKey maps requestedModel through the account's
// Antigravity mapping and, when ctx carries a boolean ThinkingEnabled value,
// applies the thinking suffix handling to the mapped name. It returns "" when
// the mapping yields no model.
func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requestedModel string) string {
	mapped := mapAntigravityModel(account, requestedModel)
	if mapped == "" {
		return ""
	}
	// Thinking affects the final Antigravity model name
	// (for example claude-sonnet-4-5 -> claude-sonnet-4-5-thinking).
	thinking, present := ctx.Value(ctxkey.ThinkingEnabled).(bool)
	if !present {
		return mapped
	}
	return applyThinkingModelSuffix(mapped, thinking)
}
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
......
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
func TestIsModelRateLimited(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
expected bool
}{
{
name: "official model ID hit - claude-sonnet-4-5",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: true,
},
{
name: "official model ID hit via mapping - request claude-3-5-sonnet, mapped to claude-sonnet-4-5",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "claude-sonnet-4-5",
},
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-3-5-sonnet",
expected: true,
},
{
name: "no rate limit - expired",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: false,
},
{
name: "no rate limit - no matching key",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-flash": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: false,
},
{
name: "no rate limit - unsupported model",
account: &Account{},
requestedModel: "gpt-4",
expected: false,
},
{
name: "no rate limit - empty model",
account: &Account{},
requestedModel: "",
expected: false,
},
{
name: "gemini model hit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-high",
expected: true,
},
{
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-preview",
expected: true,
},
{
name: "non-antigravity platform - gemini-3-pro-preview NOT mapped",
account: &Account{
Platform: PlatformGemini,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-preview",
expected: false, // gemini 平台不走 antigravity 映射
},
{
name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-opus-4-5-thinking",
expected: true,
},
{
name: "no scope fallback - claude_sonnet should not match",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-3-5-sonnet-20241022",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.isModelRateLimitedWithContext(context.Background(), tt.requestedModel)
if result != tt.expected {
t.Errorf("isModelRateLimited(%q) = %v, want %v", tt.requestedModel, result, tt.expected)
}
})
}
}
// TestIsModelRateLimited_Antigravity_ThinkingAffectsModelKey checks that, with
// ThinkingEnabled set in the context, a request for claude-sonnet-4-5 is matched
// against the rate limit stored for claude-sonnet-4-5-thinking.
func TestIsModelRateLimited_Antigravity_ThinkingAffectsModelKey(t *testing.T) {
	resetAt := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
	account := &Account{
		Platform: PlatformAntigravity,
		Extra: map[string]any{
			modelRateLimitsKey: map[string]any{
				"claude-sonnet-4-5-thinking": map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}
	thinkingCtx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
	limited := account.isModelRateLimitedWithContext(thinkingCtx, "claude-sonnet-4-5")
	if !limited {
		t.Errorf("expected model to be rate limited")
	}
}
func TestGetModelRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
future5m := now.Add(5 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "model rate limited - direct hit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "model rate limited - via mapping",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "claude-sonnet-4-5",
},
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-3-5-sonnet",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "expired rate limit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "no rate limit data",
account: &Account{},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "no scope fallback",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-3-5-sonnet-20241022",
minExpected: 0,
maxExpected: 0,
},
{
name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-opus-4-5-thinking",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetModelRateLimitRemainingTimeWithContext(context.Background(), tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetModelRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
// TestGetRateLimitRemainingTime checks the combined remaining time when model
// and quota-scope limits are both, individually, or not at all present; the
// cases with both expect the larger of the two.
func TestGetRateLimitRemainingTime(t *testing.T) {
	now := time.Now()
	future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
	future5m := now.Add(5 * time.Minute).Format(time.RFC3339)

	// entry builds a single-name rate-limit map with the given reset time.
	entry := func(name, resetAt string) map[string]any {
		return map[string]any{
			name: map[string]any{
				"rate_limit_reset_at": resetAt,
			},
		}
	}

	cases := []struct {
		name           string
		account        *Account
		requestedModel string
		minExpected    time.Duration
		maxExpected    time.Duration
	}{
		{
			name:           "nil account",
			account:        nil,
			requestedModel: "claude-sonnet-4-5",
			minExpected:    0,
			maxExpected:    0,
		},
		{
			name: "model remaining > scope remaining - returns model",
			account: &Account{
				Platform: PlatformAntigravity,
				Extra: map[string]any{
					modelRateLimitsKey:        entry("claude-sonnet-4-5", future15m), // 15 minutes
					antigravityQuotaScopesKey: entry("claude", future5m),             // 5 minutes
				},
			},
			requestedModel: "claude-sonnet-4-5",
			minExpected:    14 * time.Minute, // should return the larger value, 15 minutes
			maxExpected:    16 * time.Minute,
		},
		{
			name: "scope remaining > model remaining - returns scope",
			account: &Account{
				Platform: PlatformAntigravity,
				Extra: map[string]any{
					modelRateLimitsKey:        entry("claude-sonnet-4-5", future5m), // 5 minutes
					antigravityQuotaScopesKey: entry("claude", future15m),           // 15 minutes
				},
			},
			requestedModel: "claude-sonnet-4-5",
			minExpected:    14 * time.Minute, // should return the larger value, 15 minutes
			maxExpected:    16 * time.Minute,
		},
		{
			name: "only model rate limited",
			account: &Account{
				Platform: PlatformAntigravity,
				Extra: map[string]any{
					modelRateLimitsKey: entry("claude-sonnet-4-5", future5m),
				},
			},
			requestedModel: "claude-sonnet-4-5",
			minExpected:    4 * time.Minute,
			maxExpected:    6 * time.Minute,
		},
		{
			name: "only scope rate limited",
			account: &Account{
				Platform: PlatformAntigravity,
				Extra: map[string]any{
					antigravityQuotaScopesKey: entry("claude", future5m),
				},
			},
			requestedModel: "claude-sonnet-4-5",
			minExpected:    4 * time.Minute,
			maxExpected:    6 * time.Minute,
		},
		{
			name: "neither rate limited",
			account: &Account{
				Platform: PlatformAntigravity,
			},
			requestedModel: "claude-sonnet-4-5",
			minExpected:    0,
			maxExpected:    0,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.account.GetRateLimitRemainingTimeWithContext(context.Background(), tc.requestedModel)
			if got < tc.minExpected || got > tc.maxExpected {
				t.Errorf("GetRateLimitRemainingTime() = %v, want between %v and %v", got, tc.minExpected, tc.maxExpected)
			}
		})
	}
}
......@@ -346,47 +346,6 @@ func isInstructionsEmpty(reqBody map[string]any) bool {
return strings.TrimSpace(str) == ""
}
// ReplaceWithCodexInstructions overwrites reqBody["instructions"] with the
// built-in Codex CLI instructions when the two differ.
//
// Both sides are compared after trimming surrounding whitespace. A missing or
// non-string "instructions" entry is treated as the empty string.
//
// It returns true only when reqBody was actually modified. It returns false
// when reqBody is nil, when the built-in instructions are empty, or when the
// request already carries the built-in instructions.
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
	if reqBody == nil {
		// Assigning into a nil map panics, and there is nothing to replace anyway.
		return false
	}

	codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
	if codexInstructions == "" {
		return false
	}

	existingInstructions, _ := reqBody["instructions"].(string)
	if strings.TrimSpace(existingInstructions) == codexInstructions {
		return false
	}

	reqBody["instructions"] = codexInstructions
	return true
}
// IsInstructionError reports whether errorMessage looks like a complaint about
// the instructions / system prompt of a request.
//
// The check is a case-insensitive substring search against a fixed keyword
// list. An empty message never matches.
func IsInstructionError(errorMessage string) bool {
	if len(errorMessage) == 0 {
		return false
	}

	normalized := strings.ToLower(errorMessage)
	needles := [...]string{
		"instruction",
		"instructions",
		"system prompt",
		"system message",
		"invalid prompt",
		"prompt format",
	}
	for i := range needles {
		if strings.Contains(normalized, needles[i]) {
			return true
		}
	}
	return false
}
// filterCodexInput 按需过滤 item_reference 与 id。
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
func filterCodexInput(input []any, preserveReferences bool) []any {
......
......@@ -187,14 +187,70 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
for input, expected := range cases {
require.Equal(t, expected, normalizeCodexModel(input))
}
}
// TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions covers
// the Codex CLI path when the request already carries instructions: the
// transform must leave them untouched while still reporting a modification.
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
	setupCodexCache(t)

	const custom = "user custom instructions"
	body := map[string]any{
		"model":        "gpt-5.1",
		"instructions": custom,
		"input":        []any{},
	}

	outcome := applyCodexOAuthTransform(body, true)

	got, isString := body["instructions"].(string)
	require.True(t, isString)
	require.Equal(t, custom, got)
	// The instructions are unchanged, but other fields (e.g. store, stream)
	// may have been rewritten, so Modified is still expected to be true.
	require.True(t, outcome.Modified)
}
// TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty covers the
// Codex CLI path when the request has no instructions: the transform must
// fill in a non-empty instructions string and report a modification.
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
	setupCodexCache(t)

	body := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}

	outcome := applyCodexOAuthTransform(body, true)

	filled, isString := body["instructions"].(string)
	require.True(t, isString)
	require.NotEmpty(t, filled)
	require.True(t, outcome.Modified)
}
// TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions covers the
// non-Codex-CLI path: the opencode instructions are used, and the cache
// prepared by setupCodexCache holds the value "header".
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
	setupCodexCache(t)

	body := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}

	outcome := applyCodexOAuthTransform(body, false)

	injected, isString := body["instructions"].(string)
	require.True(t, isString)
	// "header" is the cached content written by setupCodexCache.
	require.Equal(t, "header", injected)
	require.True(t, outcome.Modified)
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
// Windows 使用 USERPROFILE,Unix 使用 HOME。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
t.Setenv("USERPROFILE", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
......@@ -210,24 +266,6 @@ func setupCodexCache(t *testing.T) {
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}
// TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions covers
// the Codex CLI path: instructions already present on the request must not be
// changed by the transform.
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
	setupCodexCache(t)

	const existing = "existing instructions"
	body := map[string]any{
		"model":        "gpt-5.1",
		"instructions": existing,
	}

	outcome := applyCodexOAuthTransform(body, true) // isCodexCLI=true

	got, isString := body["instructions"].(string)
	require.True(t, isString)
	require.Equal(t, existing, got)
	// Modified may still be true because other fields can be rewritten, so it
	// is deliberately not asserted; only the instructions must stay the same.
	_ = outcome
}
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache(t)
......
......@@ -332,7 +332,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
if shouldClearStickySession(account, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
......@@ -498,7 +498,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
}
......@@ -1087,30 +1087,6 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformOpenAI,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......
......@@ -204,6 +204,22 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil
}
// IncrModelCallCount is a test stub: it ignores its arguments and always
// reports a count of zero with no error.
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
	var count int64
	return count, nil
}
// GetModelLoadBatch is a test stub: it ignores its arguments and always
// returns a nil map with no error.
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
	var loads map[int64]*ModelLoadInfo
	return loads, nil
}
// FindGeminiSession is a test stub: it never finds a session and returns the
// zero value for every result.
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
	uuid, accountID, found = "", 0, false
	return uuid, accountID, found
}
// SaveGeminiSession is a test stub: it stores nothing and always succeeds.
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
	var err error
	return err
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
......
......@@ -67,8 +67,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok {
platform[acc.Platform] = &PlatformAvailability{
......@@ -86,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
p.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
}
for _, grp := range acc.Groups {
......@@ -118,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
g.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
}
displayGroupID := int64(0)
......@@ -158,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec
}
}
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
......
......@@ -255,3 +255,142 @@ func (s *OpsService) GetConcurrencyStats(
return platform, group, account, &collectedAt, nil
}
// listAllActiveUsersForOps collects every user whose status is StatusActive by
// paging through userRepo.ListWithFilters with a page size of
// opsAccountsPageSize.
//
// A nil receiver or nil userRepo yields an empty, non-nil slice. Paging stops
// on the first empty page, once the collected count reaches pageInfo.Total,
// on a short page, or after 10,000 pages (which is logged). A repository
// error is returned as-is and the users gathered so far are discarded.
func (s *OpsService) listAllActiveUsersForOps(ctx context.Context) ([]User, error) {
	if s == nil || s.userRepo == nil {
		return []User{}, nil
	}

	collected := make([]User, 0, 128)
	for pageNo := 1; ; {
		batch, info, err := s.userRepo.ListWithFilters(
			ctx,
			pagination.PaginationParams{Page: pageNo, PageSize: opsAccountsPageSize},
			UserListFilters{Status: StatusActive},
		)
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			return collected, nil
		}
		collected = append(collected, batch...)

		switch {
		case info != nil && int64(len(collected)) >= info.Total:
			// The repository reports that everything has been fetched.
			return collected, nil
		case len(batch) < opsAccountsPageSize:
			// A short page means there is no further page.
			return collected, nil
		}

		pageNo++
		if pageNo > 10_000 {
			// Safety valve against a repository that never signals the end.
			log.Printf("[Ops] listAllActiveUsersForOps: aborting after too many pages")
			return collected, nil
		}
	}
}
// getUsersLoadMapBestEffort looks up load information for the given users via
// concurrencyService.GetUsersLoadBatch and returns it keyed by user ID.
//
// Users with a non-positive ID are ignored. Duplicate IDs are collapsed, and
// the largest Concurrency seen for an ID is used as its MaxConcurrency. The
// lookup is issued in chunks of opsConcurrencyBatchChunkSize. A failing chunk
// is logged and skipped, so its users are simply absent from the result.
// The returned map is never nil.
func (s *OpsService) getUsersLoadMapBestEffort(ctx context.Context, users []User) map[int64]*UserLoadInfo {
	if s == nil || s.concurrencyService == nil || len(users) == 0 {
		return map[int64]*UserLoadInfo{}
	}

	// Collapse duplicates, keeping the highest limit so load is not under-reported.
	limitByID := make(map[int64]int, len(users))
	for i := range users {
		id, limit := users[i].ID, users[i].Concurrency
		if id <= 0 {
			continue
		}
		if known, seen := limitByID[id]; seen && limit <= known {
			continue
		}
		limitByID[id] = limit
	}

	requests := make([]UserWithConcurrency, 0, len(limitByID))
	for id, limit := range limitByID {
		requests = append(requests, UserWithConcurrency{ID: id, MaxConcurrency: limit})
	}

	loads := make(map[int64]*UserLoadInfo, len(requests))
	for start := 0; start < len(requests); start += opsConcurrencyBatchChunkSize {
		stop := len(requests)
		if next := start + opsConcurrencyBatchChunkSize; next < stop {
			stop = next
		}

		chunk, err := s.concurrencyService.GetUsersLoadBatch(ctx, requests[start:stop])
		if err != nil {
			// Best-effort: a failed chunk must not break the ops UI.
			log.Printf("[Ops] GetUsersLoadBatch failed: %v", err)
			continue
		}
		for id, info := range chunk {
			loads[id] = info
		}
	}
	return loads
}
// GetUserConcurrencyStats reports current concurrency usage per active user,
// keyed by user ID.
//
// It fails fast when RequireMonitoringEnabled or the user listing returns an
// error. Users with a non-positive ID, and users with neither in-use slots nor
// queued requests, are left out. LoadPercentage is only computed when the
// user's capacity is positive. The returned timestamp is taken after the user
// listing and before the load lookup.
func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*UserConcurrencyInfo, *time.Time, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, err
	}

	activeUsers, err := s.listAllActiveUsersForOps(ctx)
	if err != nil {
		return nil, nil, err
	}

	snapshotAt := time.Now()
	loads := s.getUsersLoadMapBestEffort(ctx, activeUsers)

	stats := make(map[int64]*UserConcurrencyInfo)
	for _, user := range activeUsers {
		if user.ID <= 0 {
			continue
		}

		var inUse, queued int64
		if load := loads[user.ID]; load != nil {
			inUse = int64(load.CurrentConcurrency)
			queued = int64(load.WaitingCount)
		}

		// Idle users carry no information for the ops view.
		if inUse == 0 && queued == 0 {
			continue
		}

		capacity := int64(user.Concurrency)
		var loadPct float64
		if capacity > 0 {
			loadPct = float64(inUse) / float64(capacity) * 100
		}

		stats[user.ID] = &UserConcurrencyInfo{
			UserID:         user.ID,
			UserEmail:      user.Email,
			Username:       user.Username,
			CurrentInUse:   inUse,
			MaxCapacity:    capacity,
			LoadPercentage: loadPct,
			WaitingInQueue: queued,
		}
	}

	return stats, &snapshotAt, nil
}
......@@ -37,6 +37,17 @@ type AccountConcurrencyInfo struct {
WaitingInQueue int64 `json:"waiting_in_queue"`
}
// UserConcurrencyInfo is the JSON payload describing one user's real-time
// concurrency usage.
type UserConcurrencyInfo struct {
	// Identity of the user the figures belong to.
	UserID    int64  `json:"user_id"`
	UserEmail string `json:"user_email"`
	Username  string `json:"username"`

	// CurrentInUse is the number of concurrency slots currently occupied.
	CurrentInUse int64 `json:"current_in_use"`
	// MaxCapacity is the user's configured concurrency limit.
	MaxCapacity int64 `json:"max_capacity"`
	// LoadPercentage is CurrentInUse / MaxCapacity * 100; it stays 0 when
	// MaxCapacity is not positive.
	LoadPercentage float64 `json:"load_percentage"`
	// WaitingInQueue is the number of requests waiting for a slot.
	WaitingInQueue int64 `json:"waiting_in_queue"`
}
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
......
......@@ -576,7 +576,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
action = "streamGenerateContent"
}
if account.Platform == PlatformAntigravity {
_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body)
_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body, false)
} else {
_, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body)
}
......@@ -586,7 +586,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
if s.antigravityGatewayService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"}
}
_, err = s.antigravityGatewayService.Forward(ctx, c, account, body)
_, err = s.antigravityGatewayService.Forward(ctx, c, account, body, false)
case PlatformGemini:
if s.geminiCompatService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"}
......
......@@ -27,6 +27,7 @@ type OpsService struct {
cfg *config.Config
accountRepo AccountRepository
userRepo UserRepository
// getAccountAvailability is a unit-test hook for overriding account availability lookup.
getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error)
......@@ -43,6 +44,7 @@ func NewOpsService(
settingRepo SettingRepository,
cfg *config.Config,
accountRepo AccountRepository,
userRepo UserRepository,
concurrencyService *ConcurrencyService,
gatewayService *GatewayService,
openAIGatewayService *OpenAIGatewayService,
......@@ -55,6 +57,7 @@ func NewOpsService(
cfg: cfg,
accountRepo: accountRepo,
userRepo: userRepo,
concurrencyService: concurrencyService,
gatewayService: gatewayService,
......@@ -424,13 +427,23 @@ func isSensitiveKey(key string) bool {
return false
}
// Whitelist: known non-sensitive fields that contain sensitive substrings
// (e.g., "max_tokens" contains "token" but is just an API parameter).
// Token 计数 / 预算字段不是凭据,应保留用于排错。
// 白名单保持尽量窄,避免误把真实敏感信息"反脱敏"。
switch k {
case "max_tokens", "max_completion_tokens", "max_output_tokens",
"completion_tokens", "prompt_tokens", "total_tokens",
"input_tokens", "output_tokens",
"cache_creation_input_tokens", "cache_read_input_tokens":
case "max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"max_tokens_to_sample",
"budget_tokens",
"prompt_tokens",
"completion_tokens",
"input_tokens",
"output_tokens",
"total_tokens",
"token_count",
"cache_creation_input_tokens",
"cache_read_input_tokens":
return false
}
......@@ -576,7 +589,18 @@ func trimArrayField(root map[string]any, field string, maxBytes int) (map[string
func shrinkToEssentials(root map[string]any) map[string]any {
out := make(map[string]any)
for _, key := range []string{"model", "stream", "max_tokens", "temperature", "top_p", "top_k"} {
for _, key := range []string{
"model",
"stream",
"max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"thinking",
"temperature",
"top_p",
"top_k",
} {
if v, ok := root[key]; ok {
out[key] = v
}
......
package service
import (
"encoding/json"
"testing"
)
// TestIsSensitiveKey_TokenBudgetKeysNotRedacted pins both sides of
// isSensitiveKey: token counter / budget field names must not be classified as
// sensitive, while credential-like names must be.
//
// The allowlist below mirrors every key whitelisted in isSensitiveKey,
// including the two cache token counters.
func TestIsSensitiveKey_TokenBudgetKeysNotRedacted(t *testing.T) {
	t.Parallel()

	// These names contain "token" but only describe counts or budgets.
	for _, key := range []string{
		"max_tokens",
		"max_output_tokens",
		"max_input_tokens",
		"max_completion_tokens",
		"max_tokens_to_sample",
		"budget_tokens",
		"prompt_tokens",
		"completion_tokens",
		"input_tokens",
		"output_tokens",
		"total_tokens",
		"token_count",
		"cache_creation_input_tokens",
		"cache_read_input_tokens",
	} {
		if isSensitiveKey(key) {
			t.Fatalf("expected key %q to NOT be treated as sensitive", key)
		}
	}

	// Credential-like names must keep being redacted.
	for _, key := range []string{
		"authorization",
		"Authorization",
		"access_token",
		"refresh_token",
		"id_token",
		"session_token",
		"token",
		"client_secret",
		"private_key",
		"signature",
	} {
		if !isSensitiveKey(key) {
			t.Fatalf("expected key %q to be treated as sensitive", key)
		}
	}
}
// TestSanitizeAndTrimRequestBody_PreservesTokenBudgetFields feeds a small
// request body through sanitizeAndTrimRequestBody and checks that max_tokens
// and thinking.budget_tokens survive while access_token is redacted.
func TestSanitizeAndTrimRequestBody_PreservesTokenBudgetFields(t *testing.T) {
	t.Parallel()

	const payload = `{"model":"claude-3","max_tokens":123,"thinking":{"type":"enabled","budget_tokens":456},"access_token":"abc","messages":[{"role":"user","content":"hi"}]}`

	sanitized, _, _ := sanitizeAndTrimRequestBody([]byte(payload), 10*1024)
	if sanitized == "" {
		t.Fatalf("expected non-empty sanitized output")
	}

	var decoded map[string]any
	if err := json.Unmarshal([]byte(sanitized), &decoded); err != nil {
		t.Fatalf("unmarshal sanitized output: %v", err)
	}

	// encoding/json decodes JSON numbers into float64 when the target is any.
	maxTokens, isNumber := decoded["max_tokens"].(float64)
	if !isNumber || maxTokens != 123 {
		t.Fatalf("expected max_tokens=123, got %#v", decoded["max_tokens"])
	}

	thinking, isObject := decoded["thinking"].(map[string]any)
	if !isObject || thinking == nil {
		t.Fatalf("expected thinking object to be preserved, got %#v", decoded["thinking"])
	}

	budget, isNumber := thinking["budget_tokens"].(float64)
	if !isNumber || budget != 456 {
		t.Fatalf("expected thinking.budget_tokens=456, got %#v", thinking["budget_tokens"])
	}

	if redacted := decoded["access_token"]; redacted != "[REDACTED]" {
		t.Fatalf("expected access_token to be redacted, got %#v", redacted)
	}
}
// TestShrinkToEssentials_IncludesThinking checks that shrinkToEssentials keeps
// the "thinking" entry of the request body.
func TestShrinkToEssentials_IncludesThinking(t *testing.T) {
	t.Parallel()

	thinking := map[string]any{
		"type":          "enabled",
		"budget_tokens": 200,
	}
	messages := []any{
		map[string]any{"role": "user", "content": "first"},
		map[string]any{"role": "user", "content": "last"},
	}
	body := map[string]any{
		"model":      "claude-3",
		"max_tokens": 100,
		"thinking":   thinking,
		"messages":   messages,
	}

	essentials := shrinkToEssentials(body)
	if _, kept := essentials["thinking"]; !kept {
		t.Fatalf("expected thinking to be included in essentials: %#v", essentials)
	}
}
......@@ -387,14 +387,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
......@@ -407,14 +399,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
if err != nil {
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
......@@ -423,15 +407,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
resetAt := time.Unix(ts, 0)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
return
}
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
return
}
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
......@@ -448,17 +423,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
// shouldScopeClaudeSonnetRateLimit reports whether a rate limit should be
// recorded for the Claude Sonnet scope only. That is the case when the account
// is on PlatformAnthropic and the upstream error message extracted from
// responseBody mentions "sonnet" (case-insensitive).
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
	if account == nil {
		return false
	}
	if account.Platform != PlatformAnthropic {
		return false
	}

	upstream := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
	normalized := strings.ToLower(upstream)
	return normalized != "" && strings.Contains(normalized, "sonnet")
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
......
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestFilterByMinPriority(t *testing.T) {
t.Run("empty slice", func(t *testing.T) {
result := filterByMinPriority(nil)
require.Empty(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 1)
require.Equal(t, int64(1), result[0].account.ID)
})
t.Run("multiple accounts same priority", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 3)
})
t.Run("filters to min priority only", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, Priority: 1}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 4, Priority: 1}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(2), result[0].account.ID)
require.Equal(t, int64(4), result[1].account.ID)
})
}
func TestFilterByMinLoadRate(t *testing.T) {
t.Run("empty slice", func(t *testing.T) {
result := filterByMinLoadRate(nil)
require.Empty(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 1)
require.Equal(t, int64(1), result[0].account.ID)
})
t.Run("multiple accounts same load rate", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 3)
})
t.Run("filters to min load rate only", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 80}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 4}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(2), result[0].account.ID)
require.Equal(t, int64(4), result[1].account.ID)
})
t.Run("zero load rate", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(1), result[0].account.ID)
require.Equal(t, int64(3), result[1].account.ID)
})
}
func TestSelectByLRU(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
muchEarlier := now.Add(-2 * time.Hour)
t.Run("empty slice", func(t *testing.T) {
result := selectByLRU(nil, false)
require.Nil(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(1), result.account.ID)
})
t.Run("selects least recently used", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(2), result.account.ID)
})
t.Run("nil LastUsedAt preferred over non-nil", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(2), result.account.ID)
})
t.Run("multiple nil LastUsedAt random selection", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
}
// 多次调用应该随机选择,验证结果都在候选范围内
validIDs := map[int64]bool{1: true, 2: true, 3: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates")
}
})
t.Run("multiple same LastUsedAt random selection", func(t *testing.T) {
sameTime := now
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}},
}
// 多次调用应该随机选择
validIDs := map[int64]bool{1: true, 2: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates")
}
})
t.Run("preferOAuth selects from OAuth accounts when multiple nil", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
}
// preferOAuth 时,应该从 OAuth 类型中选择
oauthIDs := map[int64]bool{2: true, 3: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, true)
require.NotNil(t, result)
require.True(t, oauthIDs[result.account.ID], "should select from OAuth accounts")
}
})
t.Run("preferOAuth falls back to all when no OAuth", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
}
// 没有 OAuth 时,从所有候选中选择
validIDs := map[int64]bool{1: true, 2: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, true)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID])
}
})
t.Run("preferOAuth only affects same LastUsedAt accounts", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &earlier, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &now, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, true)
require.NotNil(t, result)
// 有不同 LastUsedAt 时,按时间选择最早的,不受 preferOAuth 影响
require.Equal(t, int64(1), result.account.ID)
})
}
// TestLayeredFilterIntegration chains the three selection layers in order —
// filterByMinPriority, filterByMinLoadRate, then selectByLRU — and checks the
// size of each intermediate set as well as the final pick.
func TestLayeredFilterIntegration(t *testing.T) {
	now := time.Now()
	earlier := now.Add(-1 * time.Hour)
	muchEarlier := now.Add(-2 * time.Hour)

	// runLayers applies the layers to pool and verifies each stage.
	runLayers := func(t *testing.T, pool []accountWithLoad, wantAfterPriority, wantAfterLoad int, wantID int64) {
		t.Helper()

		byPriority := filterByMinPriority(pool)
		require.Len(t, byPriority, wantAfterPriority)

		byLoad := filterByMinLoadRate(byPriority)
		require.Len(t, byLoad, wantAfterLoad)

		winner := selectByLRU(byLoad, false)
		require.NotNil(t, winner)
		require.Equal(t, wantID, winner.account.ID)
	}

	t.Run("full layered selection", func(t *testing.T) {
		// Accounts with differing priority, load rate and last-use time.
		pool := []accountWithLoad{
			// Priority 1, load 50%.
			{account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
			// Priority 1, load 20% (lowest).
			{account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
			// Priority 1, load 20% (lowest), used even earlier.
			{account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
			// Priority 2 (less preferred).
			{account: &Account{ID: 4, Priority: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
		}
		// Min priority keeps IDs 1, 2, 3; min load rate keeps IDs 2, 3;
		// LRU then picks ID 3 because muchEarlier is the oldest timestamp.
		runLayers(t, pool, 3, 2, 3)
	})

	t.Run("all same priority and load rate", func(t *testing.T) {
		pool := []accountWithLoad{
			{account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
			{account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
			{account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
		}
		// Neither filter narrows the set; LRU picks the oldest entry.
		runLayers(t, pool, 3, 3, 3)
	})
}
......@@ -151,6 +151,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int
return s.accountRepo.GetByID(fallbackCtx, accountID)
}
// UpdateAccountInCache writes a single account to the scheduler cache right
// away via cache.SetAccount, so that a change such as a model rate limit takes
// effect immediately. It is a no-op returning nil when the cache or the
// account is nil; otherwise the error from SetAccount is returned.
func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error {
	if s.cache != nil && account != nil {
		return s.cache.SetAccount(ctx, account)
	}
	return nil
}
func (s *SchedulerSnapshotService) runInitialRebuild() {
if s.cache == nil {
return
......
......@@ -23,32 +23,90 @@ import (
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
// - 模型限流超过阈值:清理
// - 模型限流未超过阈值:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
// nil account, error/disabled status, unschedulable, temporary unschedulable,
// and model rate limiting scenarios.
func TestShouldClearStickySession(t *testing.T) {
// Table-driven check of shouldClearStickySession across account states and
// per-model rate limits stored under Extra["model_rate_limits"].
now := time.Now()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
// Short rate-limit window (below the threshold: the sticky session should not be cleared).
shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339)
// Long rate-limit window (above the threshold: the sticky session should be cleared).
longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
want bool
}{
// NOTE(review): the next seven rows omit requestedModel and are repeated
// right after with it; this looks like the pre-change side of a diff —
// confirm which set belongs in the file.
{name: "nil account", account: nil, want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
{name: "nil account", account: nil, requestedModel: "", want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, requestedModel: "", want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, requestedModel: "", want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, requestedModel: "", want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false},
// Model rate-limit cases.
{
name: "model rate limited short duration",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": shortRateLimitReset,
},
},
},
},
requestedModel: "claude-sonnet-4",
want: false, // below the threshold: do not clear
},
{
name: "model rate limited long duration",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": longRateLimitReset,
},
},
},
},
requestedModel: "claude-sonnet-4",
want: true, // above the threshold: clear
},
{
name: "model rate limited different model",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": longRateLimitReset,
},
},
},
},
requestedModel: "claude-opus-4", // a different model is requested
want: false, // other models are not affected
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// NOTE(review): both the one-argument and the two-argument call appear
// here; presumably only the two-argument form is current — confirm.
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
require.Equal(t, tt.want, shouldClearStickySession(tt.account, tt.requestedModel))
})
}
}
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ============ 临时限流单元测试 ============
// TestMatchTempUnschedKeyword 测试关键词匹配函数
func TestMatchTempUnschedKeyword(t *testing.T) {
tests := []struct {
name string
body string
keywords []string
want string
}{
{
name: "match_first",
body: "server is overloaded",
keywords: []string{"overloaded", "capacity"},
want: "overloaded",
},
{
name: "match_second",
body: "no capacity available",
keywords: []string{"overloaded", "capacity"},
want: "capacity",
},
{
name: "no_match",
body: "internal error",
keywords: []string{"overloaded", "capacity"},
want: "",
},
{
name: "empty_body",
body: "",
keywords: []string{"overloaded"},
want: "",
},
{
name: "empty_keywords",
body: "server is overloaded",
keywords: []string{},
want: "",
},
{
name: "whitespace_keyword",
body: "server is overloaded",
keywords: []string{" ", "overloaded"},
want: "overloaded",
},
{
// matchTempUnschedKeyword 期望 body 已经是小写的
// 所以要测试大小写不敏感匹配,需要传入小写的 body
name: "case_insensitive_body_lowered",
body: "server is overloaded", // body 已经是小写
keywords: []string{"OVERLOADED"}, // keyword 会被转为小写比较
want: "OVERLOADED",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchTempUnschedKeyword(tt.body, tt.keywords)
require.Equal(t, tt.want, got)
})
}
}
// TestAccountIsSchedulable_TempUnschedulable checks that IsSchedulable is
// expected to report false only while TempUnschedulableUntil lies in the future.
func TestAccountIsSchedulable_TempUnschedulable(t *testing.T) {
	later := time.Now().Add(10 * time.Minute)
	earlier := time.Now().Add(-10 * time.Minute)

	type schedCase struct {
		name    string
		account *Account
		want    bool
	}

	cases := []schedCase{
		{
			name:    "temp_unschedulable_active",
			account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &later},
			want:    false,
		},
		{
			name:    "temp_unschedulable_expired",
			account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &earlier},
			want:    true,
		},
		{
			name:    "no_temp_unschedulable",
			account: &Account{Status: StatusActive, Schedulable: true},
			want:    true,
		},
		{
			name: "temp_unschedulable_with_rate_limit",
			account: &Account{
				Status:                 StatusActive,
				Schedulable:            true,
				TempUnschedulableUntil: &later,
				// An already-expired rate limit is expected to have no effect.
				RateLimitResetAt: &earlier,
			},
			// The temporary unschedulable window is still in force.
			want: false,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, tc.account.IsSchedulable())
		})
	}
}
// TestAccount_IsTempUnschedulableEnabled checks the on/off switch read from
// the "temp_unschedulable_enabled" credential.
func TestAccount_IsTempUnschedulableEnabled(t *testing.T) {
	type switchCase struct {
		name    string
		account *Account
		want    bool
	}

	cases := []switchCase{
		{
			name:    "enabled",
			account: &Account{Credentials: map[string]any{"temp_unschedulable_enabled": true}},
			want:    true,
		},
		{
			name:    "disabled",
			account: &Account{Credentials: map[string]any{"temp_unschedulable_enabled": false}},
			want:    false,
		},
		{
			name:    "not_set",
			account: &Account{Credentials: map[string]any{}},
			want:    false,
		},
		{
			name:    "nil_credentials",
			account: &Account{},
			want:    false,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, tc.account.IsTempUnschedulableEnabled())
		})
	}
}
// TestAccount_GetTempUnschedulableRules checks how many rules are returned
// for populated, empty, missing and nil credentials.
func TestAccount_GetTempUnschedulableRules(t *testing.T) {
	type rulesCase struct {
		name      string
		account   *Account
		wantCount int
	}

	// Two well-formed rules, using float64 numbers and []any keyword lists.
	twoRules := []any{
		map[string]any{
			"error_code":       float64(503),
			"keywords":         []any{"overloaded"},
			"duration_minutes": float64(5),
		},
		map[string]any{
			"error_code":       float64(500),
			"keywords":         []any{"internal"},
			"duration_minutes": float64(10),
		},
	}

	cases := []rulesCase{
		{
			name:      "has_rules",
			account:   &Account{Credentials: map[string]any{"temp_unschedulable_rules": twoRules}},
			wantCount: 2,
		},
		{
			name:      "empty_rules",
			account:   &Account{Credentials: map[string]any{"temp_unschedulable_rules": []any{}}},
			wantCount: 0,
		},
		{
			name:      "no_rules",
			account:   &Account{Credentials: map[string]any{}},
			wantCount: 0,
		},
		{
			name:      "nil_credentials",
			account:   &Account{},
			wantCount: 0,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			require.Len(t, tc.account.GetTempUnschedulableRules(), tc.wantCount)
		})
	}
}
// TestTempUnschedulableRule_Parse checks that a single rule stored in the
// credentials is returned with its error code, keywords and duration.
func TestTempUnschedulableRule_Parse(t *testing.T) {
	rawRule := map[string]any{
		"error_code":       float64(503),
		"keywords":         []any{"overloaded", "capacity"},
		"duration_minutes": float64(5),
	}
	acct := &Account{
		Credentials: map[string]any{
			"temp_unschedulable_rules": []any{rawRule},
		},
	}

	parsed := acct.GetTempUnschedulableRules()
	require.Len(t, parsed, 1)

	first := parsed[0]
	require.Equal(t, 503, first.ErrorCode)
	require.Equal(t, []string{"overloaded", "capacity"}, first.Keywords)
	require.Equal(t, 5, first.DurationMinutes)
}
// TestTruncateTempUnschedMessage 测试消息截断
func TestTruncateTempUnschedMessage(t *testing.T) {
tests := []struct {
name string
body []byte
maxBytes int
want string
}{
{
name: "short_message",
body: []byte("short"),
maxBytes: 100,
want: "short",
},
{
// 截断后会 TrimSpace,所以末尾的空格会被移除
name: "truncate_long_message",
body: []byte("this is a very long message that needs to be truncated"),
maxBytes: 20,
want: "this is a very long", // 截断后 TrimSpace
},
{
name: "empty_body",
body: []byte{},
maxBytes: 100,
want: "",
},
{
name: "zero_max_bytes",
body: []byte("test"),
maxBytes: 0,
want: "",
},
{
name: "whitespace_trimmed",
body: []byte(" test "),
maxBytes: 100,
want: "test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := truncateTempUnschedMessage(tt.body, tt.maxBytes)
require.Equal(t, tt.want, got)
})
}
}
// TestTempUnschedState checks that the fields of TempUnschedState hold the
// values assigned to them.
func TestTempUnschedState(t *testing.T) {
	triggeredAt := time.Now()
	expiresAt := triggeredAt.Add(5 * time.Minute)

	st := new(TempUnschedState)
	st.UntilUnix = expiresAt.Unix()
	st.TriggeredAtUnix = triggeredAt.Unix()
	st.StatusCode = 503
	st.MatchedKeyword = "overloaded"
	st.RuleIndex = 0
	st.ErrorMessage = "Server is overloaded"

	require.Equal(t, 503, st.StatusCode)
	require.Equal(t, "overloaded", st.MatchedKeyword)
	require.Equal(t, 0, st.RuleIndex)

	// Verify the timestamps.
	require.Equal(t, expiresAt.Unix(), st.UntilUnix)
	require.Equal(t, triggeredAt.Unix(), st.TriggeredAtUnix)
}
// TestAccount_TempUnschedulableUntil checks how the TempUnschedulableUntil
// field (future, past or nil) is expected to affect IsSchedulable.
func TestAccount_TempUnschedulableUntil(t *testing.T) {
	later := time.Now().Add(10 * time.Minute)
	earlier := time.Now().Add(-10 * time.Minute)

	type untilCase struct {
		name        string
		account     *Account
		schedulable bool
	}

	cases := []untilCase{
		{
			name:        "active_temp_unsched_not_schedulable",
			account:     &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &later},
			schedulable: false,
		},
		{
			name:        "expired_temp_unsched_is_schedulable",
			account:     &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &earlier},
			schedulable: true,
		},
		{
			name:        "nil_temp_unsched_is_schedulable",
			account:     &Account{Status: StatusActive, Schedulable: true},
			schedulable: true,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.schedulable, tc.account.IsSchedulable())
		})
	}
}
-- Force set default Antigravity model_mapping.
--
-- Notes:
-- - Applies to both Antigravity OAuth and Upstream accounts.
-- - Overwrites existing credentials.model_mapping.
-- - Removes legacy credentials.model_whitelist.
-- - Rows with a non-NULL deleted_at are left untouched.
--
-- How it works:
-- - NULL credentials are treated as an empty JSON object.
-- - The 'model_whitelist' and 'model_mapping' keys are removed with the
--   jsonb "-" operator; every other credentials key is preserved.
-- - The default mapping below is then merged in with the jsonb "||" operator.
-- - Each mapping entry is "requested model name": "model name to use".
UPDATE accounts
SET credentials = (COALESCE(credentials, '{}'::jsonb) - 'model_whitelist' - 'model_mapping') || '{
"model_mapping": {
"claude-opus-4-6": "claude-opus-4-6",
"claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}
}'::jsonb
-- Restrict the update to Antigravity accounts that are not soft-deleted.
WHERE platform = 'antigravity'
AND deleted_at IS NULL;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment