"backend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "94e067a2e2db42f7ba6d55451d6fedb1d10ca196"
Unverified Commit 149e4267 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #523 from touwaeriol/feat/antigravity-improvements

feat: Antigravity improvements and scope-to-model rate limiting refactor
parents 5fa93ebd 9a479d1b
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream returns a fixed response for every request.
type epFixedUpstream struct {
statusCode int
body string
calls int
}
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.calls++
return &http.Response{
StatusCode: u.statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(u.body)),
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type epAccountRepo struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
}
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrCalls++
return nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func saveAndSetBaseURLs(t *testing.T) {
t.Helper()
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvail := antigravity.DefaultURLAvailability
antigravity.BaseURLs = []string{"https://ep-test.example"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
t.Cleanup(func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvail
})
}
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
return antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[ep-test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: handleError,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
tests := []struct {
name string
upstreamStatus int
upstreamBody string
customCodes []any
expectHandleError int
expectUpstream int
expectStatusCode int
}{
{
name: "429_in_custom_codes_matched",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(429)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "429_not_in_custom_codes_skipped",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "500_in_custom_codes_matched",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(500)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_not_in_custom_codes_skipped",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(429)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": tt.customCodes,
},
}
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
tempRulesAccount := func(rules []any) *Account {
return &Account{
ID: 200,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": rules,
},
}
}
overloadedRule := map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
}
rateLimitRule := map[string]any{
"error_code": float64(429),
"keywords": []any{"rate limited keyword"},
"duration_minutes": float64(5),
}
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{rateLimitRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
result, err := svc.antigravityRetryLoop(p)
// Context cancellation during backoff proves default retry was entered
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
// rateLimitService is nil — must not panic
svc := &AntigravityGatewayService{rateLimitService: nil}
account := &Account{
ID: 300,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
// Should not panic; enters the default retry path (eventually times out)
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
// Plain OAuth account with no error policy configured
account := &Account{
ID: 400,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "no_policy_oauth_returns_none",
account: &Account{
ID: 1,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
// no custom error codes, no temp rules
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_hit_returns_matched",
account: &Account{
ID: 2,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyMatched,
},
{
name: "custom_error_codes_miss_returns_skipped",
account: &Account{
ID: 3,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 503,
body: []byte(`"error"`),
expected: ErrorPolicySkipped,
},
{
name: "temp_unschedulable_hit_returns_temp_unscheduled",
account: &Account{
ID: 4,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{
ID: 5,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`random msg`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_override_temp_unschedulable",
account: &Account{
ID: 6,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func TestApplyErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expectedHandled bool
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
{
name: "none_not_handled",
account: &Account{
ID: 10,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
handleErrorCalls: 0,
},
{
name: "skipped_handled_no_handleError",
account: &Account{
ID: 11,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 0,
},
{
name: "matched_handled_calls_handleError",
account: &Account{
ID: 12,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 1,
},
{
name: "temp_unscheduled_returns_switch_error",
account: &Account{
ID: 13,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedSwitchErr: true,
handleErrorCalls: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{
rateLimitService: rlSvc,
}
var handleErrorCount int
p := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: tt.account,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
},
isStickySession: true,
}
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, retErr, &switchErr)
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
} else {
require.NoError(t, retErr)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type errorPolicyRepoStub struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
lastErrorMsg string
}
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrCalls++
r.lastErrorMsg = errorMsg
return nil
}
...@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co ...@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil return nil
} }
...@@ -216,30 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context ...@@ -216,30 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil return nil
} }
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct { type mockGroupRepoForGateway struct {
groups map[int64]*Group groups map[int64]*Group
getByIDCalls int getByIDCalls int
......
...@@ -6,9 +6,19 @@ import ( ...@@ -6,9 +6,19 @@ import (
"fmt" "fmt"
"math" "math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果 // ParsedRequest 保存网关请求的预解析结果
// //
// 性能优化说明: // 性能优化说明:
...@@ -22,20 +32,22 @@ import ( ...@@ -22,20 +32,22 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层 // 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct { type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发) Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称 Model string // 请求的模型名称
Stream bool // 是否为流式请求 Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和) MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容 System any // system 字段内容
Messages []any // messages 数组 Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截) MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
} }
// ParseGatewayRequest 解析网关请求体并返回结构化结果 // ParseGatewayRequest 解析网关请求体并返回结构化结果。
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { // 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any var req map[string]any
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
return nil, err return nil, err
...@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { ...@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID parsed.MetadataUserID = userID
} }
} }
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。 switch protocol {
if system, ok := req["system"]; ok { case domain.PlatformGemini:
parsed.HasSystem = true // Gemini 原生格式: systemInstruction.parts / contents
parsed.System = system if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
} if parts, ok := sysInst["parts"].([]any); ok {
if messages, ok := req["messages"].([]any); ok { parsed.System = parts
parsed.Messages = messages }
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
} }
// thinking: {type: "enabled"} // thinking: {type: "enabled"}
......
...@@ -4,12 +4,13 @@ import ( ...@@ -4,12 +4,13 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestParseGatewayRequest(t *testing.T) { func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream) require.True(t, parsed.Stream)
...@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) { ...@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled) require.True(t, parsed.ThinkingEnabled)
...@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { ...@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func TestParseGatewayRequest_MaxTokens(t *testing.T) { func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens) require.Equal(t, 1, parsed.MaxTokens)
} }
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens) require.Equal(t, 0, parsed.MaxTokens)
} }
func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`) body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err) require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem) require.True(t, parsed.HasSystem)
...@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { ...@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func TestParseGatewayRequest_InvalidModelType(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`) body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body) _, err := ParseGatewayRequest(body, "")
require.Error(t, err) require.Error(t, err)
} }
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`) body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body) _, err := ParseGatewayRequest(body, "")
require.Error(t, err) require.Error(t, err)
} }
// ============ Gemini 原生格式解析测试 ============
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
body := []byte(`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
require.Nil(t, parsed.System, "no systemInstruction means nil System")
}
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
body := []byte(`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
parts, ok := parsed.System.([]any)
require.True(t, ok)
require.Len(t, parts, 1)
partMap, ok := parts[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "You are a helpful assistant.", partMap["text"])
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
body := []byte(`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Equal(t, "gemini-2.5-pro", parsed.Model)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
// Gemini 格式下 system/messages 字段应被忽略
body := []byte(`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
require.Nil(t, parsed.System, "no systemInstruction = nil System")
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
}
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
body := []byte(`{"contents": []}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Empty(t, parsed.Messages)
}
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
body := []byte(`{"model": "gemini-2.5-flash"}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Nil(t, parsed.Messages)
require.Equal(t, "gemini-2.5-flash", parsed.Model)
}
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body := []byte(`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
require.NoError(t, err)
require.True(t, parsed.HasSystem)
require.Equal(t, "real system", parsed.System)
require.Len(t, parsed.Messages, 1)
msg, ok := parsed.Messages[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "real content", msg["content"])
}
func TestFilterThinkingBlocks(t *testing.T) { func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool { containsThinkingBlock := func(body []byte) bool {
var req map[string]any var req map[string]any
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -17,6 +16,7 @@ import ( ...@@ -17,6 +16,7 @@ import (
"os" "os"
"regexp" "regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -245,9 +246,6 @@ var ( ...@@ -245,9 +246,6 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers(参考CRS项目) // allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
"accept": true, "accept": true,
...@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{ ...@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。 // GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
// //
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service. // GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities. // Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface { type GatewayCache interface {
...@@ -295,32 +286,6 @@ type GatewayCache interface { ...@@ -295,32 +286,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable // Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话(Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
} }
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil // derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...@@ -415,6 +380,7 @@ type GatewayService struct { ...@@ -415,6 +380,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository userGroupRateRepo UserGroupRateRepository
cache GatewayCache cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService billingService *BillingService
...@@ -448,6 +414,7 @@ func NewGatewayService( ...@@ -448,6 +414,7 @@ func NewGatewayService(
deferredService *DeferredService, deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider, claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache, sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService { ) *GatewayService {
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -457,6 +424,7 @@ func NewGatewayService( ...@@ -457,6 +424,7 @@ func NewGatewayService(
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo, userGroupRateRepo: userGroupRateRepo,
cache: cache, cache: cache,
digestStore: digestStore,
cfg: cfg, cfg: cfg,
schedulerSnapshot: schedulerSnapshot, schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService, concurrencyService: concurrencyService,
...@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent) return s.hashContent(cacheableContent)
} }
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串 // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil { if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System) systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" { if systemText != "" {
...@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
} }
for _, msg := range parsed.Messages { for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok { if m, ok := msg.(map[string]any); ok {
msgText := s.extractTextFromContent(m["content"]) if content, exists := m["content"]; exists {
if msgText != "" { // Anthropic: messages[].content
_, _ = combined.WriteString(msgText) if msgText := s.extractTextFromContent(content); msgText != "" {
_, _ = combined.WriteString(msgText)
}
} else if parts, ok := m["parts"].([]any); ok {
// Gemini: contents[].parts[].text
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
_, _ = combined.WriteString(text)
}
}
}
} }
} }
} }
...@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID ...@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息(uuid, accountID) // 返回最长匹配的会话信息(uuid, accountID)
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.cache == nil { if digestChain == "" || s.digestStore == nil {
return "", 0, false return "", 0, "", false
} }
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) return s.digestStore.Find(groupID, prefixHash, digestChain)
} }
// SaveGeminiSession 保存 Gemini 会话 // SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.cache == nil { if digestChain == "" || s.digestStore == nil {
return nil return nil
} }
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
} }
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) // FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.cache == nil { if digestChain == "" || s.digestStore == nil {
return "", 0, false return "", 0, "", false
} }
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain) return s.digestStore.Find(groupID, prefixHash, digestChain)
} }
// SaveAnthropicSession 保存 Anthropic 会话 // SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.cache == nil { if digestChain == "" || s.digestStore == nil {
return nil return nil
} }
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
} }
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
...@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { ...@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
} }
func (s *GatewayService) hashContent(content string) string { func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content)) h := xxhash.Sum64String(content)
return hex.EncodeToString(hash[:16]) // 32字符 return strconv.FormatUint(h, 36)
} }
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
...@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
} }
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return a.account.LastUsedAt.Before(*b.account.LastUsedAt) return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
} }
}) })
shuffleWithinSortGroups(routingAvailable)
// 4. 尝试获取槽位 // 4. 尝试获取槽位
for _, item := range routingAvailable { for _, item := range routingAvailable {
...@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil return result, nil
} }
} else { } else {
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad var available []accountWithLoad
for _, acc := range candidates { for _, acc := range candidates {
loadInfo := loadMap[acc.ID] loadInfo := loadMap[acc.ID]
...@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) // 分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { for len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) // 1. 取优先级最小的集合
modelToAccountIDs := make(map[string][]int64) candidates := filterByMinPriority(available)
for _, item := range available { // 2. 取负载率最低的集合
mappedModel := mapAntigravityModel(item.account, requestedModel) candidates = filterByMinLoadRate(candidates)
if mappedModel == "" { // 3. LRU 选择最久未用的账号
continue selected := selectByLRU(candidates, preferOAuth)
} if selected == nil {
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) break
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
} }
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新选择 result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
selectedID := selected.account.ID if err == nil && result.Acquired {
newAvailable := make([]accountWithLoad, 0, len(available)-1) // 会话数量限制检查
for _, acc := range available { if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
if acc.account.ID != selectedID { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
newAvailable = append(newAvailable, acc) } else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
available = newAvailable
} }
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新进行分层过滤 // 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1) newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available { for _, acc := range available {
if acc.account.ID != selectedID { if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc) newAvailable = append(newAvailable, acc)
}
} }
available = newAvailable
} }
available = newAvailable
} }
} }
...@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return a.LastUsedAt.Before(*b.LastUsedAt) return a.LastUsedAt.Before(*b.LastUsedAt)
} }
}) })
shuffleWithinPriorityAndLastUsed(accounts)
} }
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) // shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 // 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
// 如果有多个账号具有相同的最小调用次数,则随机选择一个 func shuffleWithinSortGroups(accounts []accountWithLoad) {
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { if len(accounts) <= 1 {
if len(accounts) == 0 { return
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
} }
i := 0
// 1. 计算平均调用次数(用于新账号冷启动) for i < len(accounts) {
var totalCallCount int64 j := i + 1
var countWithCalls int for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
for _, acc := range accounts { j++
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { }
totalCallCount += info.CallCount if j-i > 1 {
countWithCalls++ mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
} }
i = j
} }
}
var avgCallCount int64 // sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
if countWithCalls > 0 { func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
avgCallCount = totalCallCount / int64(countWithCalls) if a.account.Priority != b.account.Priority {
return false
} }
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
// 2. 获取每个账号的有效调用次数 return false
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
} }
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
}
// 3. 找到最小调用次数 // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
minCount := getEffectiveCallCount(accounts[0]) func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
for _, acc := range accounts[1:] { if len(accounts) <= 1 {
if c := getEffectiveCallCount(acc); c < minCount { return
minCount = c
}
} }
i := 0
// 4. 收集所有具有最小调用次数的账号 for i < len(accounts) {
var candidateIdxs []int j := i + 1
for i, acc := range accounts { for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
if getEffectiveCallCount(acc) == minCount { j++
candidateIdxs = append(candidateIdxs, i) }
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
} }
i = j
} }
}
// 5. 如果只有一个候选,直接返回 // sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
if len(candidateIdxs) == 1 { func sameAccountGroup(a, b *Account) bool {
return &accounts[candidateIdxs[0]] if a.Priority != b.Priority {
return false
} }
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
}
// 6. preferOAuth 处理 // sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
if preferOAuth { func sameLastUsedAt(a, b *time.Time) bool {
var oauthIdxs []int switch {
for _, idx := range candidateIdxs { case a == nil && b == nil:
if accounts[idx].account.Type == AccountTypeOAuth { return true
oauthIdxs = append(oauthIdxs, idx) case a == nil || b == nil:
} return false
} default:
if len(oauthIdxs) > 0 { return a.Unix() == b.Unix()
candidateIdxs = oauthIdxs
}
} }
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
} }
// sortCandidatesForFallback 根据配置选择排序策略 // sortCandidatesForFallback 根据配置选择排序策略
...@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) { ...@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
...@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { ...@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil return normalized, nil
} }
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope,跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group // GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group // It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
...@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { ...@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body, "")
if err != nil { if err != nil {
b.Fatalf("解析请求失败: %v", err) b.Fatalf("解析请求失败: %v", err)
} }
......
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
statusCode int
expected bool
}{
{"401_failover", 401, true},
{"403_failover", 403, true},
{"429_failover", 429, true},
{"529_failover", 529, true},
{"500_failover", 500, true},
{"502_failover", 502, true},
{"503_failover", 503, true},
{"400_no_failover", 400, false},
{"404_no_failover", 404, false},
{"422_no_failover", 422, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "gemini_apikey_custom_codes_hit",
account: &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 429,
body: []byte(`{"error":"rate limited"}`),
expected: ErrorPolicyMatched,
},
{
name: "gemini_apikey_custom_codes_miss",
account: &Account{
ID: 101,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicySkipped,
},
{
name: "gemini_apikey_no_custom_codes_returns_none",
account: &Account{
ID: 102,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicyNone,
},
{
name: "gemini_apikey_temp_unschedulable_hit",
account: &Account{
ID: 103,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{
ID: 104,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicyIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
account *Account
statusCode int
respBody []byte
expectFailover bool // expect UpstreamFailoverError
expectHandleError bool // expect handleGeminiUpstreamError to be called
expectShouldFailover bool // for None path, whether shouldFailover triggers
}{
{
name: "custom_codes_matched_429_failover",
account: &Account{
ID: 200,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
},
{
name: "custom_codes_skipped_500_no_failover",
account: &Account{
ID: 201,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
respBody: []byte(`{"error":"internal"}`),
expectFailover: false,
expectHandleError: false,
},
{
name: "temp_unschedulable_matched_failover",
account: &Account{
ID: 202,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
respBody: []byte(`overloaded`),
expectFailover: true,
expectHandleError: true,
},
{
name: "no_policy_429_failover_via_shouldFailover",
account: &Account{
ID: 203,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
expectShouldFailover: true,
},
{
name: "no_policy_400_no_failover",
account: &Account{
ID: 204,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 400,
respBody: []byte(`{"error":"bad request"}`),
expectFailover: false,
expectHandleError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &geminiErrorPolicyRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &GeminiMessagesCompatService{
accountRepo: repo,
rateLimitService: rlSvc,
}
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var handleErrorCalled bool
var gotFailover bool
ctx := context.Background()
statusCode := tt.statusCode
respBody := tt.respBody
account := tt.account
headers := http.Header{}
if svc.rateLimitService != nil {
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
case ErrorPolicySkipped:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover = false
handleErrorCalled = false
goto verify
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
gotFailover = true
goto verify
}
}
// ErrorPolicyNone → original logic
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
gotFailover = true
}
verify:
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
if tt.expectShouldFailover {
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
svc := &GeminiMessagesCompatService{
rateLimitService: nil,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx := context.Background()
account := &Account{
ID: 300,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if svc.rateLimitService != nil {
t.Fatal("rateLimitService should be nil for this test")
}
// shouldFailoverGeminiUpstreamError still works
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require.NotPanics(t, func() {
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setRateLimitedCalls int
setTempCalls int
}
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrorCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
r.setRateLimitedCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.setTempCalls++
return nil
}
...@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false // 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil { if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
} case ErrorPolicySkipped:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) upstreamReqID := resp.Header.Get(requestIDHeader)
if tempMatched { if upstreamReqID == "" {
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID = resp.Header.Get("x-goog-request-id")
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
} }
upstreamDetail = truncateString(string(respBody), maxBytes) return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" { if upstreamReqID == "" {
...@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting. // This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
...@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
if tempMatched { // 统一错误策略:自定义错误码 + 临时不可调度
evBody := unwrapIfNeeded(isOAuth, respBody) if s.rateLimitService != nil {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) case ErrorPolicySkipped:
upstreamDetail := "" respBody = unwrapIfNeeded(isOAuth, respBody)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { contentType := resp.Header.Get("Content-Type")
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if contentType == "" {
if maxBytes <= 0 { contentType = "application/json"
maxBytes = 2048
} }
upstreamDetail = truncateString(string(evBody), maxBytes) c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody) evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
......
...@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont ...@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil return nil
} }
...@@ -269,30 +266,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, ...@@ -269,30 +266,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil return nil
} }
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background() ctx := context.Background()
......
...@@ -6,26 +6,11 @@ import ( ...@@ -6,26 +6,11 @@ import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
) )
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) // shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% // XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
func shortHash(data []byte) string { func shortHash(data []byte) string {
...@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m ...@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
return base64.RawURLEncoding.EncodeToString(hash[:12]) return base64.RawURLEncoding.EncodeToString(hash[:12])
} }
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值 // ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID} // 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
...@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string { ...@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 // geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:" const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey // GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 // 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
......
package service package service
import ( import (
"context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 // TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) { func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache() store := NewDigestSessionStore()
groupID := int64(1) groupID := int64(1)
prefixHash := "test_prefix_hash" prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345" sessionUUID := "session-uuid-12345"
...@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { ...@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 1 chain: %s", chain1) t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话 // 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1) _, _, _, found := store.Find(groupID, prefixHash, chain1)
if found { if found {
t.Error("Round 1: should not find existing session") t.Error("Round 1: should not find existing session")
} }
// 保存第一轮会话 // 保存第一轮会话(首轮无旧 chain)
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
// 模拟第二轮对话(用户继续对话) // 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{ req2 := &antigravity.GeminiRequest{
...@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { ...@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 2 chain: %s", chain2) t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配) // 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
if !found { if !found {
t.Error("Round 2: should find session via prefix matching") t.Error("Round 2: should find session via prefix matching")
} }
...@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { ...@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
} }
// 保存第二轮会话 // 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
// 模拟第三轮对话 // 模拟第三轮对话
req3 := &antigravity.GeminiRequest{ req3 := &antigravity.GeminiRequest{
...@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { ...@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 3 chain: %s", chain3) t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配) // 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
if !found { if !found {
t.Error("Round 3: should find session via prefix matching") t.Error("Round 3: should find session via prefix matching")
} }
...@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { ...@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
if foundAccID != accountID { if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
} }
t.Log("✓ Continuous conversation session matching works correctly!")
} }
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 // TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) { func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache() store := NewDigestSessionStore()
groupID := int64(1) groupID := int64(1)
prefixHash := "test_prefix_hash" prefixHash := "test_prefix_hash"
...@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { ...@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
}, },
} }
chain1 := BuildGeminiDigestChain(req1) chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100) store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
// 第二个完全不同的会话 // 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{ req2 := &antigravity.GeminiRequest{
...@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { ...@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
chain2 := BuildGeminiDigestChain(req2) chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配 // 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2) _, _, _, found := store.Find(groupID, prefixHash, chain2)
if found { if found {
t.Error("Different conversations should not match") t.Error("Different conversations should not match")
} }
t.Log("✓ Different conversations are correctly isolated!")
} }
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) // TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache() store := NewDigestSessionStore()
groupID := int64(1) groupID := int64(1)
prefixHash := "test_prefix_hash" prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号 // 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1 store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
// 第二轮 -> 账号 2 store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3 // 查找更长的链,应该返回最长匹配(账号 3)
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3) _, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
// 查找应该返回最长匹配(账号 3)
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
if !found { if !found {
t.Error("Should find session") t.Error("Should find session")
} }
if accID != 3 { if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID) t.Errorf("Should match longest prefix (account 3), got account %d", accID)
} }
t.Log("✓ Longest prefix matching works correctly!")
} }
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background
...@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { ...@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
} }
} }
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) { func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
...@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) { ...@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
} }
}) })
} }
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ 基础优先级测试 ============
func TestGenerateSessionHash_NilParsedRequest(t *testing.T) {
svc := &GatewayService{}
require.Empty(t, svc.GenerateSessionHash(nil))
}
func TestGenerateSessionHash_EmptyRequest(t *testing.T) {
svc := &GatewayService{}
require.Empty(t, svc.GenerateSessionHash(&ParsedRequest{}))
}
func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, "metadata session_id should have highest priority")
}
// ============ System + Messages 基础测试 ============
func TestGenerateSessionHash_SystemPlusMessages(t *testing.T) {
svc := &GatewayService{}
withSystem := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
withoutSystem := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
h1 := svc.GenerateSessionHash(withSystem)
h2 := svc.GenerateSessionHash(withoutSystem)
require.NotEmpty(t, h1)
require.NotEmpty(t, h2)
require.NotEqual(t, h1, h2, "system prompt should be part of digest, producing different hash")
}
func TestGenerateSessionHash_SystemOnlyProducesHash(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
}
hash := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, hash, "system prompt alone should produce a hash as part of full digest")
}
func TestGenerateSessionHash_DifferentSystemsSameMessages(t *testing.T) {
svc := &GatewayService{}
parsed1 := &ParsedRequest{
System: "You are assistant A.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "You are assistant B.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h1, h2, "different system prompts with same messages should produce different hashes")
}
func TestGenerateSessionHash_SameSystemSameMessages(t *testing.T) {
svc := &GatewayService{}
mk := func() *ParsedRequest {
return &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
}
h1 := svc.GenerateSessionHash(mk())
h2 := svc.GenerateSessionHash(mk())
require.Equal(t, h1, h2, "same system + same messages should produce identical hash")
}
func TestGenerateSessionHash_DifferentMessagesProduceDifferentHash(t *testing.T) {
svc := &GatewayService{}
parsed1 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "help me with Go"},
},
}
parsed2 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "help me with Python"},
},
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h1, h2, "same system but different messages should produce different hashes")
}
// ============ SessionContext 核心测试 ============
func TestGenerateSessionHash_DifferentSessionContextProducesDifferentHash(t *testing.T) {
svc := &GatewayService{}
// 相同消息 + 不同 SessionContext → 不同 hash(解决碰撞问题的核心场景)
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "192.168.1.1",
UserAgent: "Mozilla/5.0",
APIKeyID: 100,
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "10.0.0.1",
UserAgent: "curl/7.0",
APIKeyID: 200,
},
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.NotEmpty(t, h1)
require.NotEmpty(t, h2)
require.NotEqual(t, h1, h2, "same messages but different SessionContext should produce different hashes")
}
func TestGenerateSessionHash_SameSessionContextProducesSameHash(t *testing.T) {
svc := &GatewayService{}
mk := func() *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "192.168.1.1",
UserAgent: "Mozilla/5.0",
APIKeyID: 100,
},
}
}
h1 := svc.GenerateSessionHash(mk())
h2 := svc.GenerateSessionHash(mk())
require.Equal(t, h1, h2, "same messages + same SessionContext should produce identical hash")
}
func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "192.168.1.1",
UserAgent: "Mozilla/5.0",
APIKeyID: 100,
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash,
"metadata session_id should take priority over SessionContext")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{}
withCtx := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: nil,
}
withoutCtx := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
h1 := svc.GenerateSessionHash(withCtx)
h2 := svc.GenerateSessionHash(withoutCtx)
require.Equal(t, h1, h2, "nil SessionContext should produce same hash as no SessionContext")
}
// ============ 多轮连续会话测试 ============
func TestGenerateSessionHash_ContinuousConversation_HashChangesWithMessages(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 模拟连续会话:每增加一轮对话,hash 应该不同(内容累积变化)
round1 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: ctx,
}
round2 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "Hi there!"},
map[string]any{"role": "user", "content": "How are you?"},
},
SessionContext: ctx,
}
round3 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "Hi there!"},
map[string]any{"role": "user", "content": "How are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well!"},
map[string]any{"role": "user", "content": "Tell me a joke"},
},
SessionContext: ctx,
}
h1 := svc.GenerateSessionHash(round1)
h2 := svc.GenerateSessionHash(round2)
h3 := svc.GenerateSessionHash(round3)
require.NotEmpty(t, h1)
require.NotEmpty(t, h2)
require.NotEmpty(t, h3)
require.NotEqual(t, h1, h2, "different conversation rounds should produce different hashes")
require.NotEqual(t, h2, h3, "each new round should produce a different hash")
require.NotEqual(t, h1, h3, "round 1 and round 3 should differ")
}
func TestGenerateSessionHash_ContinuousConversation_SameRoundSameHash(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 同一轮对话重复请求(如重试)应产生相同 hash
mk := func() *ParsedRequest {
return &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "Hi there!"},
map[string]any{"role": "user", "content": "How are you?"},
},
SessionContext: ctx,
}
}
h1 := svc.GenerateSessionHash(mk())
h2 := svc.GenerateSessionHash(mk())
require.Equal(t, h1, h2, "same conversation state should produce identical hash on retry")
}
// ============ 消息回退测试 ============
func TestGenerateSessionHash_MessageRollback(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 模拟消息回退:用户删掉最后一轮再重发
original := &ParsedRequest{
System: "System prompt",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "msg1"},
map[string]any{"role": "assistant", "content": "reply1"},
map[string]any{"role": "user", "content": "msg2"},
map[string]any{"role": "assistant", "content": "reply2"},
map[string]any{"role": "user", "content": "msg3"},
},
SessionContext: ctx,
}
// 回退到 msg2 后,用新的 msg3 替代
rollback := &ParsedRequest{
System: "System prompt",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "msg1"},
map[string]any{"role": "assistant", "content": "reply1"},
map[string]any{"role": "user", "content": "msg2"},
map[string]any{"role": "assistant", "content": "reply2"},
map[string]any{"role": "user", "content": "different msg3"},
},
SessionContext: ctx,
}
hOrig := svc.GenerateSessionHash(original)
hRollback := svc.GenerateSessionHash(rollback)
require.NotEqual(t, hOrig, hRollback, "rollback with different last message should produce different hash")
}
func TestGenerateSessionHash_MessageRollbackSameContent(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 回退后重新发送相同内容 → 相同 hash(合理的粘性恢复)
mk := func() *ParsedRequest {
return &ParsedRequest{
System: "System prompt",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "msg1"},
map[string]any{"role": "assistant", "content": "reply1"},
map[string]any{"role": "user", "content": "msg2"},
},
SessionContext: ctx,
}
}
h1 := svc.GenerateSessionHash(mk())
h2 := svc.GenerateSessionHash(mk())
require.Equal(t, h1, h2, "rollback and resend same content should produce same hash")
}
// ============ 相同 System、不同用户消息 ============
func TestGenerateSessionHash_SameSystemDifferentUsers(t *testing.T) {
svc := &GatewayService{}
// 两个不同用户使用相同 system prompt 但发送不同消息
user1 := &ParsedRequest{
System: "You are a code reviewer.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "Review this Go code"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "vscode",
APIKeyID: 1,
},
}
user2 := &ParsedRequest{
System: "You are a code reviewer.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "Review this Python code"},
},
SessionContext: &SessionContext{
ClientIP: "2.2.2.2",
UserAgent: "vscode",
APIKeyID: 2,
},
}
h1 := svc.GenerateSessionHash(user1)
h2 := svc.GenerateSessionHash(user2)
require.NotEqual(t, h1, h2, "different users with different messages should get different hashes")
}
func TestGenerateSessionHash_SameSystemSameMessageDifferentContext(t *testing.T) {
svc := &GatewayService{}
// 这是修复的核心场景:两个不同用户发送完全相同的 system + messages(如 "hello")
// 有了 SessionContext 后应该产生不同 hash
user1 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "Mozilla/5.0",
APIKeyID: 10,
},
}
user2 := &ParsedRequest{
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "2.2.2.2",
UserAgent: "Mozilla/5.0",
APIKeyID: 20,
},
}
h1 := svc.GenerateSessionHash(user1)
h2 := svc.GenerateSessionHash(user2)
require.NotEqual(t, h1, h2, "CRITICAL: same system+messages but different users should get different hashes")
}
// ============ SessionContext 各字段独立影响测试 ============
func TestGenerateSessionHash_SessionContext_IPDifference(t *testing.T) {
svc := &GatewayService{}
base := func(ip string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: ip,
UserAgent: "same-ua",
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("1.1.1.1"))
h2 := svc.GenerateSessionHash(base("2.2.2.2"))
require.NotEqual(t, h1, h2, "different IP should produce different hash")
}
func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Mozilla/5.0"))
h2 := svc.GenerateSessionHash(base("curl/7.0"))
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
}
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
svc := &GatewayService{}
base := func(keyID int64) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "same-ua",
APIKeyID: keyID,
},
}
}
h1 := svc.GenerateSessionHash(base(1))
h2 := svc.GenerateSessionHash(base(2))
require.NotEqual(t, h1, h2, "different APIKeyID should produce different hash")
}
// ============ 多用户并发相同消息场景 ============
func TestGenerateSessionHash_MultipleUsersSameFirstMessage(t *testing.T) {
svc := &GatewayService{}
// 模拟 5 个不同用户同时发送 "hello" → 应该产生 5 个不同的 hash
hashes := make(map[string]bool)
for i := 0; i < 5; i++ {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "192.168.1." + string(rune('1'+i)),
UserAgent: "client-" + string(rune('A'+i)),
APIKeyID: int64(i + 1),
},
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h)
require.False(t, hashes[h], "hash collision detected for user %d", i)
hashes[h] = true
}
require.Len(t, hashes, 5, "5 different users should produce 5 unique hashes")
}
// ============ 连续会话粘性:多轮对话同一用户 ============
func TestGenerateSessionHash_SameUserGrowingConversation(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "browser", APIKeyID: 42}
// 模拟同一用户的连续会话,每轮 hash 不同但同用户重试保持一致
messages := []map[string]any{
{"role": "user", "content": "msg1"},
{"role": "assistant", "content": "reply1"},
{"role": "user", "content": "msg2"},
{"role": "assistant", "content": "reply2"},
{"role": "user", "content": "msg3"},
{"role": "assistant", "content": "reply3"},
{"role": "user", "content": "msg4"},
}
prevHash := ""
for round := 1; round <= len(messages); round += 2 {
// 构建前 round 条消息
msgs := make([]any, round)
for j := 0; j < round; j++ {
msgs[j] = messages[j]
}
parsed := &ParsedRequest{
System: "System",
HasSystem: true,
Messages: msgs,
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "round %d hash should not be empty", round)
if prevHash != "" {
require.NotEqual(t, prevHash, h, "round %d hash should differ from previous round", round)
}
prevHash = h
// 同一轮重试应该相同
h2 := svc.GenerateSessionHash(parsed)
require.Equal(t, h, h2, "retry of round %d should produce same hash", round)
}
}
// ============ 多轮消息内容结构化测试 ============
func TestGenerateSessionHash_MultipleUserMessages(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 5 条用户消息(无 assistant 回复)
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "first"},
map[string]any{"role": "user", "content": "second"},
map[string]any{"role": "user", "content": "third"},
map[string]any{"role": "user", "content": "fourth"},
map[string]any{"role": "user", "content": "fifth"},
},
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h)
// 修改中间一条消息应该改变 hash
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "first"},
map[string]any{"role": "user", "content": "CHANGED"},
map[string]any{"role": "user", "content": "third"},
map[string]any{"role": "user", "content": "fourth"},
map[string]any{"role": "user", "content": "fifth"},
},
SessionContext: ctx,
}
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h, h2, "changing any message should change the hash")
}
func TestGenerateSessionHash_MessageOrderMatters(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "alpha"},
map[string]any{"role": "user", "content": "beta"},
},
SessionContext: ctx,
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "beta"},
map[string]any{"role": "user", "content": "alpha"},
},
SessionContext: ctx,
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h1, h2, "message order should affect the hash")
}
// ============ 复杂内容格式测试 ============
func TestGenerateSessionHash_StructuredContent(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 结构化 content(数组形式)
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "Look at this"},
map[string]any{"type": "text", "text": "And this too"},
},
},
},
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "structured content should produce a hash")
}
func TestGenerateSessionHash_ArraySystemPrompt(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 数组格式的 system prompt
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant."},
map[string]any{"type": "text", "text": "Be concise."},
},
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "array system prompt should produce a hash")
}
// ============ SessionContext 与 cache_control 优先级 ============
func TestGenerateSessionHash_CacheControlOverridesSessionContext(t *testing.T) {
svc := &GatewayService{}
// 当有 cache_control: ephemeral 时,使用第 2 级优先级
// SessionContext 不应影响结果
parsed1 := &ParsedRequest{
System: []any{
map[string]any{
"type": "text",
"text": "You are a tool-specific assistant.",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "ua1",
APIKeyID: 100,
},
}
parsed2 := &ParsedRequest{
System: []any{
map[string]any{
"type": "text",
"text": "You are a tool-specific assistant.",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
SessionContext: &SessionContext{
ClientIP: "2.2.2.2",
UserAgent: "ua2",
APIKeyID: 200,
},
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.Equal(t, h1, h2, "cache_control ephemeral has higher priority, SessionContext should not affect result")
}
// ============ 边界情况 ============
func TestGenerateSessionHash_EmptyMessages(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Messages: []any{},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "test",
APIKeyID: 1,
},
}
// 空 messages + 只有 SessionContext 时,combined.Len() > 0 因为有 context 写入
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "empty messages with SessionContext should still produce a hash from context")
}
func TestGenerateSessionHash_EmptyMessagesNoContext(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Messages: []any{},
}
h := svc.GenerateSessionHash(parsed)
require.Empty(t, h, "empty messages without SessionContext should produce empty hash")
}
func TestGenerateSessionHash_SessionContextWithEmptyFields(t *testing.T) {
svc := &GatewayService{}
// SessionContext 字段为空字符串和零值时仍应影响 hash
withEmptyCtx := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "",
UserAgent: "",
APIKeyID: 0,
},
}
withoutCtx := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
}
h1 := svc.GenerateSessionHash(withEmptyCtx)
h2 := svc.GenerateSessionHash(withoutCtx)
// 有 SessionContext(即使字段为空)仍然会写入分隔符 "::" 等
require.NotEqual(t, h1, h2, "empty-field SessionContext should still differ from nil SessionContext")
}
// ============ 长对话历史测试 ============
func TestGenerateSessionHash_LongConversation(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
// 构建 20 轮对话
messages := make([]any, 0, 40)
for i := 0; i < 20; i++ {
messages = append(messages, map[string]any{
"role": "user",
"content": "user message " + string(rune('A'+i)),
})
messages = append(messages, map[string]any{
"role": "assistant",
"content": "assistant reply " + string(rune('A'+i)),
})
}
parsed := &ParsedRequest{
System: "System prompt",
HasSystem: true,
Messages: messages,
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h)
// 再加一轮应该不同
moreMessages := make([]any, len(messages)+2)
copy(moreMessages, messages)
moreMessages[len(messages)] = map[string]any{"role": "user", "content": "one more"}
moreMessages[len(messages)+1] = map[string]any{"role": "assistant", "content": "ok"}
parsed2 := &ParsedRequest{
System: "System prompt",
HasSystem: true,
Messages: moreMessages,
SessionContext: ctx,
}
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash")
}
// ============ Gemini 原生格式 session hash 测试 ============
func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) {
svc := &GatewayService{}
// Gemini 格式: contents[].parts[].text
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Hello from Gemini"},
},
},
},
SessionContext: &SessionContext{
ClientIP: "1.2.3.4",
UserAgent: "gemini-cli",
APIKeyID: 1,
},
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash")
}
func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Hello"},
},
},
},
SessionContext: ctx,
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Goodbye"},
},
},
},
SessionContext: ctx,
}
h1 := svc.GenerateSessionHash(parsed1)
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes")
}
func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
mk := func() *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Hello"},
},
},
map[string]any{
"role": "model",
"parts": []any{
map[string]any{"text": "Hi there!"},
},
},
},
SessionContext: ctx,
}
}
h1 := svc.GenerateSessionHash(mk())
h2 := svc.GenerateSessionHash(mk())
require.Equal(t, h1, h2, "same Gemini contents should produce identical hash")
}
func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
round1 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
},
SessionContext: ctx,
}
round2 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
map[string]any{
"role": "model",
"parts": []any{map[string]any{"text": "Hi!"}},
},
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "How are you?"}},
},
},
SessionContext: ctx,
}
h1 := svc.GenerateSessionHash(round1)
h2 := svc.GenerateSessionHash(round2)
require.NotEmpty(t, h1)
require.NotEmpty(t, h2)
require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round")
}
func TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) {
svc := &GatewayService{}
// 核心场景:两个不同用户发送相同 Gemini 格式消息应得到不同 hash
user1 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: "gemini-cli",
APIKeyID: 10,
},
}
user2 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
},
SessionContext: &SessionContext{
ClientIP: "2.2.2.2",
UserAgent: "gemini-cli",
APIKeyID: 20,
},
}
h1 := svc.GenerateSessionHash(user1)
h2 := svc.GenerateSessionHash(user2)
require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes")
}
func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
// systemInstruction 经 ParseGatewayRequest 解析后存入 parsed.System
withSys := &ParsedRequest{
System: []any{
map[string]any{"text": "You are a coding assistant."},
},
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
},
SessionContext: ctx,
}
withoutSys := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{map[string]any{"text": "hello"}},
},
},
SessionContext: ctx,
}
h1 := svc.GenerateSessionHash(withSys)
h2 := svc.GenerateSessionHash(withoutSys)
require.NotEqual(t, h1, h2, "systemInstruction should affect the hash")
}
func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
// 多 parts 的消息
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Part 1"},
map[string]any{"text": "Part 2"},
map[string]any{"text": "Part 3"},
},
},
},
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "multi-part Gemini message should produce a hash")
// 不同内容的多 parts
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Part 1"},
map[string]any{"text": "CHANGED"},
map[string]any{"text": "Part 3"},
},
},
},
SessionContext: ctx,
}
h2 := svc.GenerateSessionHash(parsed2)
require.NotEqual(t, h, h2, "changing a part should change the hash")
}
func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1}
// 含非 text 类型 parts(如 inline_data),应被跳过但不报错
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{"text": "Describe this image"},
map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}},
},
},
},
SessionContext: ctx,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts")
}
func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) {
svc := &GatewayService{}
ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42}
// 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。
// 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。
// 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。
round1Body := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a coding assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Write a Go function"}]}
]
}`)
round2Body := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a coding assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Write a Go function"}]},
{"role": "model", "parts": [{"text": "func hello() {}"}]},
{"role": "user", "parts": [{"text": "Add error handling"}]}
]
}`)
round3Body := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a coding assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Write a Go function"}]},
{"role": "model", "parts": [{"text": "func hello() {}"}]},
{"role": "user", "parts": [{"text": "Add error handling"}]},
{"role": "model", "parts": [{"text": "func hello() error { return nil }"}]},
{"role": "user", "parts": [{"text": "Now add tests"}]}
]
}`)
hashes := make([]string, 3)
for i, body := range [][]byte{round1Body, round2Body, round3Body} {
parsed, err := ParseGatewayRequest(body, "gemini")
require.NoError(t, err)
parsed.SessionContext = ctx
hashes[i] = svc.GenerateSessionHash(parsed)
require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1)
}
// 每轮 hash 都不同——这是预期行为
require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)")
require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)")
require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ")
// 同一轮重试应产生相同 hash
parsed1Again, err := ParseGatewayRequest(round2Body, "gemini")
require.NoError(t, err)
parsed1Again.SessionContext = ctx
h2Again := svc.GenerateSessionHash(parsed1Again)
require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash")
}
func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) {
svc := &GatewayService{}
// 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程
body := []byte(`{
"model": "gemini-2.5-pro",
"systemInstruction": {
"parts": [{"text": "You are a coding assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Write a Go function"}]},
{"role": "model", "parts": [{"text": "Here is a function..."}]},
{"role": "user", "parts": [{"text": "Now add error handling"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, "gemini")
require.NoError(t, err)
parsed.SessionContext = &SessionContext{
ClientIP: "10.0.0.1",
UserAgent: "gemini-cli/1.0",
APIKeyID: 42,
}
h := svc.GenerateSessionHash(parsed)
require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash")
// 同一请求再次解析应产生相同 hash
parsed2, err := ParseGatewayRequest(body, "gemini")
require.NoError(t, err)
parsed2.SessionContext = &SessionContext{
ClientIP: "10.0.0.1",
UserAgent: "gemini-cli/1.0",
APIKeyID: 42,
}
h2 := svc.GenerateSessionHash(parsed2)
require.Equal(t, h, h2, "same request should produce same hash")
// 不同用户发送相同请求应产生不同 hash
parsed3, err := ParseGatewayRequest(body, "gemini")
require.NoError(t, err)
parsed3.SessionContext = &SessionContext{
ClientIP: "10.0.0.2",
UserAgent: "gemini-cli/1.0",
APIKeyID: 99,
}
h3 := svc.GenerateSessionHash(parsed3)
require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash")
}
...@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) { ...@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
} }
} }
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetRateLimitRemainingTime(t *testing.T) { func TestGetRateLimitRemainingTime(t *testing.T) {
now := time.Now() now := time.Now()
future15m := now.Add(15 * time.Minute).Format(time.RFC3339) future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
...@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) { ...@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
maxExpected: 0, maxExpected: 0,
}, },
{ {
name: "model remaining > scope remaining - returns model", name: "model rate limited - 15 minutes",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "scope remaining > model remaining - returns scope",
account: &Account{ account: &Account{
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
Extra: map[string]any{ Extra: map[string]any{
modelRateLimitsKey: map[string]any{ modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{ "claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟 "rate_limit_reset_at": future15m,
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
}, },
}, },
}, },
}, },
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 minExpected: 14 * time.Minute,
maxExpected: 16 * time.Minute, maxExpected: 16 * time.Minute,
}, },
{ {
...@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) { ...@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
minExpected: 4 * time.Minute, minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute, maxExpected: 6 * time.Minute,
}, },
{
name: "only scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{ {
name: "neither rate limited", name: "neither rate limited",
account: &Account{ account: &Account{
......
...@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
} }
} else { } else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad var available []accountWithLoad
for _, acc := range candidates { for _, acc := range candidates {
loadInfo := loadMap[acc.ID] loadInfo := loadMap[acc.ID]
...@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
return a.account.LastUsedAt.Before(*b.account.LastUsedAt) return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
} }
}) })
shuffleWithinSortGroups(available)
for _, item := range available { for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
......
...@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i ...@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil return nil
} }
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now() now := time.Now()
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
......
...@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
} }
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" { if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok { if _, ok := platform[acc.Platform]; !ok {
...@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError { if hasError {
p.ErrorCount++ p.ErrorCount++
} }
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
} }
for _, grp := range acc.Groups { for _, grp := range acc.Groups {
...@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError { if hasError {
g.ErrorCount++ g.ErrorCount++
} }
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
} }
displayGroupID := int64(0) displayGroupID := int64(0)
...@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi ...@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec item.RateLimitRemainingSec = &remainingSec
} }
} }
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil { if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds()) remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment