Commit fd43be8d authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突



- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parents 792bef61 836ba14b
......@@ -6,6 +6,7 @@ import (
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
......@@ -60,8 +61,11 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
if err := svc.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
}
}
// 订阅缓存更新通知
......@@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP
}
// 刷新缓存
s.invalidateAndNotify(ctx)
refreshCtx, cancel := s.newCacheRefreshContext()
defer cancel()
s.invalidateAndNotify(refreshCtx)
return created, nil
}
......@@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP
}
// 刷新缓存
s.invalidateAndNotify(ctx)
refreshCtx, cancel := s.newCacheRefreshContext()
defer cancel()
s.invalidateAndNotify(refreshCtx)
return updated, nil
}
......@@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
}
// 刷新缓存
s.invalidateAndNotify(ctx)
refreshCtx, cancel := s.newCacheRefreshContext()
defer cancel()
s.invalidateAndNotify(refreshCtx)
return nil
}
......@@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
}
}
// 从数据库加载(repo.List 已按 priority 排序)
return s.reloadRulesFromDB(ctx)
}
// 从数据库加载(repo.List 已按 priority 排序)
// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。
func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
rules, err := s.repo.List(ctx)
if err != nil {
return err
......@@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR
s.localCacheMu.Unlock()
}
// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。
func (s *ErrorPassthroughService) clearLocalCache() {
s.localCacheMu.Lock()
s.localCache = nil
s.localCacheMu.Unlock()
}
// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。
func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 3*time.Second)
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 先失效缓存,避免后续刷新读到陈旧规则。
if s.cache != nil {
if err := s.cache.Invalidate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
}
}
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
if err := s.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s.clearLocalCache()
}
// 通知其他实例
......
......@@ -4,6 +4,7 @@ package service
import (
"context"
"errors"
"strings"
"testing"
......@@ -15,13 +16,80 @@ import (
// mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule
listErr error
getErr error
createErr error
updateErr error
deleteErr error
}
type mockErrorPassthroughCache struct {
rules []*model.ErrorPassthroughRule
hasData bool
getCalled int
setCalled int
invalidateCalled int
notifyCalled int
}
func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache {
return &mockErrorPassthroughCache{
rules: cloneRules(rules),
hasData: hasData,
}
}
func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
m.getCalled++
if !m.hasData {
return nil, false
}
return cloneRules(m.rules), true
}
func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
m.setCalled++
m.rules = cloneRules(rules)
m.hasData = true
return nil
}
func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error {
m.invalidateCalled++
m.rules = nil
m.hasData = false
return nil
}
func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error {
m.notifyCalled++
return nil
}
func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
// 单测中无需订阅行为
}
func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule {
if rules == nil {
return nil
}
out := make([]*model.ErrorPassthroughRule, len(rules))
copy(out, rules)
return out
}
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
if m.listErr != nil {
return nil, m.listErr
}
return m.rules, nil
}
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
if m.getErr != nil {
return nil, m.getErr
}
for _, r := range m.rules {
if r.ID == id {
return r, nil
......@@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode
}
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if m.createErr != nil {
return nil, m.createErr
}
rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule)
return rule, nil
}
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if m.updateErr != nil {
return nil, m.updateErr
}
for i, r := range m.rules {
if r.ID == rule.ID {
m.rules[i] = rule
......@@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error
}
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
if m.deleteErr != nil {
return m.deleteErr
}
for i, r := range m.rules {
if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...)
......@@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) {
}
}
// =============================================================================
// 测试写路径缓存刷新(Create/Update/Delete)
// =============================================================================
func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败")
created, err := svc.Create(ctx, newRule)
require.NoError(t, err)
require.NotNil(t, created)
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
matched := svc.MatchRule("anthropic", 503, body)
require.NotNil(t, matched)
assert.Equal(t, created.ID, matched.ID)
if assert.NotNil(t, matched.CustomMessage) {
assert.Equal(t, "上游请求失败", *matched.CustomMessage)
}
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule})
updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息")
_, err := svc.Update(ctx, updatedRule)
require.NoError(t, err)
oldBody := []byte(`{"message":"old keyword"}`)
oldMatched := svc.MatchRule("anthropic", 503, oldBody)
assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中")
newBody := []byte(`{"message":"new keyword"}`)
newMatched := svc.MatchRule("anthropic", 503, newBody)
require.NotNil(t, newMatched)
if assert.NotNil(t, newMatched.CustomMessage) {
assert.Equal(t, "新消息", *newMatched.CustomMessage)
}
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) {
ctx := context.Background()
rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{rule})
err := svc.Delete(ctx, 1)
require.NoError(t, err)
body := []byte(`{"message":"to be deleted"}`)
matched := svc.MatchRule("anthropic", 503, body)
assert.Nil(t, matched, "删除后规则不应再命中")
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.invalidateCalled)
assert.Equal(t, 1, cache.setCalled)
assert.Equal(t, 1, cache.notifyCalled)
}
func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) {
staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息")
latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息")
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := NewErrorPassthroughService(repo, cache)
matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`))
require.NotNil(t, matchedFresh)
assert.Equal(t, int64(1), matchedFresh.ID)
matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`))
assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存")
assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get")
assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存")
}
func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) {
ctx := context.Background()
staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息")
repo := &mockErrorPassthroughRepo{
rules: []*model.ErrorPassthroughRule{staleRule},
listErr: errors.New("db list failed"),
}
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
svc := &ErrorPassthroughService{repo: repo, cache: cache}
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
disabledRule := *staleRule
disabledRule.Enabled = false
_, err := svc.Update(ctx, &disabledRule)
require.NoError(t, err)
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
matched := svc.MatchRule("anthropic", 503, body)
assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则")
svc.localCacheMu.RLock()
assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中")
svc.localCacheMu.RUnlock()
}
func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule {
responseCode := 503
rule := &model.ErrorPassthroughRule{
ID: id,
Name: "write-path-cache-refresh",
Enabled: true,
Priority: 1,
ErrorCodes: []int{503},
Keywords: []string{keyword},
MatchMode: model.MatchModeAll,
PassthroughCode: false,
ResponseCode: &responseCode,
PassthroughBody: false,
CustomMessage: &customMsg,
}
return rule
}
// Helper functions
func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s }
//go:build unit
package service
import (
"context"
"testing"
)
func TestIsForceCacheBilling(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected bool
}{
{
name: "context without force cache billing",
ctx: context.Background(),
expected: false,
},
{
name: "context with force cache billing set to true",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, true),
expected: true,
},
{
name: "context with force cache billing set to false",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, false),
expected: false,
},
{
name: "context with wrong type value",
ctx: context.WithValue(context.Background(), ForceCacheBillingContextKey, "true"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsForceCacheBilling(tt.ctx)
if result != tt.expected {
t.Errorf("IsForceCacheBilling() = %v, want %v", result, tt.expected)
}
})
}
}
func TestWithForceCacheBilling(t *testing.T) {
ctx := context.Background()
// 原始上下文没有标记
if IsForceCacheBilling(ctx) {
t.Error("original context should not have force cache billing")
}
// 使用 WithForceCacheBilling 后应该有标记
newCtx := WithForceCacheBilling(ctx)
if !IsForceCacheBilling(newCtx) {
t.Error("new context should have force cache billing")
}
// 原始上下文应该不受影响
if IsForceCacheBilling(ctx) {
t.Error("original context should still not have force cache billing")
}
}
func TestForceCacheBilling_TokenConversion(t *testing.T) {
tests := []struct {
name string
forceCacheBilling bool
inputTokens int
cacheReadInputTokens int
expectedInputTokens int
expectedCacheReadTokens int
}{
{
name: "force cache billing converts input to cache_read",
forceCacheBilling: true,
inputTokens: 1000,
cacheReadInputTokens: 500,
expectedInputTokens: 0,
expectedCacheReadTokens: 1500, // 500 + 1000
},
{
name: "no force cache billing keeps tokens unchanged",
forceCacheBilling: false,
inputTokens: 1000,
cacheReadInputTokens: 500,
expectedInputTokens: 1000,
expectedCacheReadTokens: 500,
},
{
name: "force cache billing with zero input tokens does nothing",
forceCacheBilling: true,
inputTokens: 0,
cacheReadInputTokens: 500,
expectedInputTokens: 0,
expectedCacheReadTokens: 500,
},
{
name: "force cache billing with zero cache_read tokens",
forceCacheBilling: true,
inputTokens: 1000,
cacheReadInputTokens: 0,
expectedInputTokens: 0,
expectedCacheReadTokens: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 RecordUsage 中的 ForceCacheBilling 逻辑
usage := ClaudeUsage{
InputTokens: tt.inputTokens,
CacheReadInputTokens: tt.cacheReadInputTokens,
}
// 这是 RecordUsage 中的实际逻辑
if tt.forceCacheBilling && usage.InputTokens > 0 {
usage.CacheReadInputTokens += usage.InputTokens
usage.InputTokens = 0
}
if usage.InputTokens != tt.expectedInputTokens {
t.Errorf("InputTokens = %d, want %d", usage.InputTokens, tt.expectedInputTokens)
}
if usage.CacheReadInputTokens != tt.expectedCacheReadTokens {
t.Errorf("CacheReadInputTokens = %d, want %d", usage.CacheReadInputTokens, tt.expectedCacheReadTokens)
}
})
}
}
......@@ -219,6 +219,22 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
......@@ -335,7 +351,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAntigravity)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
......@@ -673,7 +689,7 @@ func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *tes
cfg: testConfig(),
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-sonnet-4-5", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
......@@ -1017,11 +1033,17 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
name: "Antigravity平台-支持默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
model: "claude-sonnet-4-5",
expected: true,
},
{
name: "Antigravity平台-不支持非默认映射中的claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
......@@ -1118,7 +1140,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
......@@ -1126,7 +1148,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) {
groupID := int64(30)
requestedModel := "claude-3-5-sonnet-20241022"
requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
......@@ -1171,7 +1193,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-路由粘性命中", func(t *testing.T) {
groupID := int64(31)
requestedModel := "claude-3-5-sonnet-20241022"
requestedModel := "claude-sonnet-4-5"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
......@@ -1323,7 +1345,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": resetAt.Format(time.RFC3339),
},
},
......@@ -1468,7 +1490,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
......@@ -1600,7 +1622,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-sonnet-4-5", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
......@@ -1873,6 +1895,19 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users {
result[user.ID] = &UserLoadInfo{
UserID: user.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
......@@ -2750,7 +2785,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
Concurrency: 5,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"claude-3-5-sonnet-20241022": map[string]any{
"rate_limit_reset_at": now.Format(time.RFC3339),
},
},
......
......@@ -4,6 +4,9 @@ import (
"bytes"
"encoding/json"
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ParsedRequest 保存网关请求的预解析结果
......@@ -26,6 +29,8 @@ type ParsedRequest struct {
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
MaxTokens int // max_tokens 值(用于探测请求拦截)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
......@@ -69,9 +74,62 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.Messages = messages
}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
// max_tokens
if rawMaxTokens, exists := req["max_tokens"]; exists {
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
parsed.MaxTokens = maxTokens
}
}
return parsed, nil
}
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
func parseIntegralNumber(raw any) (int, bool) {
switch v := raw.(type) {
case float64:
if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) {
return 0, false
}
if v > float64(math.MaxInt) || v < float64(math.MinInt) {
return 0, false
}
return int(v), true
case int:
return v, true
case int8:
return int(v), true
case int16:
return int(v), true
case int32:
return int(v), true
case int64:
if v > int64(math.MaxInt) || v < int64(math.MinInt) {
return 0, false
}
return int(v), true
case json.Number:
i64, err := v.Int64()
if err != nil {
return 0, false
}
if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) {
return 0, false
}
return int(i64), true
default:
return 0, false
}
}
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
......@@ -466,7 +524,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
if signature != "" && signature != antigravity.DummyThoughtSignature {
newContent = append(newContent, block)
continue
}
......
......@@ -17,6 +17,29 @@ func TestParseGatewayRequest(t *testing.T) {
require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1)
require.False(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens)
}
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
......
......@@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}
......@@ -49,6 +49,29 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
// IsForceCacheBilling 检查是否启用强制缓存计费
func IsForceCacheBilling(ctx context.Context) bool {
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
return v
}
// WithForceCacheBilling 返回带有强制缓存计费标记的上下文
func WithForceCacheBilling(ctx context.Context) context.Context {
return context.WithValue(ctx, ForceCacheBillingContextKey, true)
}
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
......@@ -207,40 +230,6 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
"read": "Read",
"edit": "Edit",
"write": "Write",
"task": "Task",
"glob": "Glob",
"grep": "Grep",
"webfetch": "WebFetch",
"websearch": "WebSearch",
"todowrite": "TodoWrite",
"question": "AskUserQuestion",
}
openCodeToolOverrides = map[string]string{
"Bash": "bash",
"Read": "read",
"Edit": "edit",
"Write": "write",
"Task": "task",
"Glob": "glob",
"Grep": "grep",
"WebFetch": "webfetch",
"WebSearch": "websearch",
"TodoWrite": "todowrite",
"AskUserQuestion": "question",
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
......@@ -284,6 +273,13 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
......@@ -299,6 +295,24 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息(uuid, accountID)
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
......@@ -309,16 +323,23 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
......@@ -328,6 +349,10 @@ func shouldClearStickySession(account *Account) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
return true
}
return false
}
......@@ -376,6 +401,7 @@ type ForwardResult struct {
type UpstreamFailoverError struct {
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
}
func (e *UpstreamFailoverError) Error() string {
......@@ -508,6 +534,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
return accountID, nil
}
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息(uuid, accountID)
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
......@@ -620,71 +663,6 @@ type claudeOAuthNormalizeOptions struct {
stripSystemCacheControl bool
}
func stripToolPrefix(value string) string {
if value == "" {
return value
}
return toolPrefixRe.ReplaceAllString(value, "")
}
func toSnakeCase(value string) string {
if value == "" {
return value
}
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
output = strings.Trim(output, "_")
return strings.ToLower(output)
}
func normalizeToolNameForClaude(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
// 只对已知的工具名进行映射,未知工具名保持原样
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728)
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
return stripped
}
if cache != nil && mapped != stripped {
cache[mapped] = stripped
}
return mapped
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
// 优先从请求时建立的映射中查找
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
// 已知工具名的硬编码映射
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
return stripped
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
if cache != nil {
if mapped, ok := cache[name]; ok {
return mapped
}
}
return name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
......@@ -703,55 +681,6 @@ func sanitizeSystemText(text string) string {
return text
}
func sanitizeToolDescription(description string) string {
if description == "" {
return description
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return description
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
schema, ok := inputSchema.(map[string]any)
if !ok {
return
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
return
}
newProperties := make(map[string]any, len(properties))
for key, value := range properties {
snakeKey := toSnakeCase(key)
newProperties[snakeKey] = value
if snakeKey != key && cache != nil {
cache[snakeKey] = key
}
}
schema["properties"] = newProperties
if required, ok := schema["required"].([]any); ok {
newRequired := make([]any, 0, len(required))
for _, item := range required {
name, ok := item.(string)
if !ok {
newRequired = append(newRequired, item)
continue
}
snakeName := toSnakeCase(name)
newRequired = append(newRequired, snakeName)
if snakeName != name && cache != nil {
cache[snakeName] = name
}
}
schema["required"] = newRequired
}
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
......@@ -772,24 +701,17 @@ func stripCacheControlFromSystemBlocks(system any) bool {
return changed
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
if len(body) == 0 {
return body, modelID, nil
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var reqRaw map[string]json.RawMessage
if err := json.Unmarshal(body, &reqRaw); err != nil {
return body, modelID, nil
return body, modelID
}
// 同时解析为 map[string]any 用于修改非 messages 字段
// 解析为 map[string]any 用于修改字段
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
return body, modelID
}
toolNameMap := make(map[string]string)
modified := false
if system, ok := req["system"]; ok {
......@@ -831,115 +753,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if rawTools, exists := req["tools"]; exists {
switch tools := rawTools.(type) {
case []any:
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
if name, ok := toolMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
modified = true
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
modified = true
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
modified = true
}
tools[idx] = toolMap
}
req["tools"] = tools
case map[string]any:
normalizedTools := make(map[string]any, len(tools))
for name, value := range tools {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized == "" {
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
toolMap["name"] = normalized
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
}
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
modified = true
}
} else {
// 确保 tools 字段存在(即使为空数组)
if _, exists := req["tools"]; !exists {
req["tools"] = []any{}
modified = true
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified := false
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
// 检查此消息是否包含 thinking 块
hasThinking := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
hasThinking = true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if hasThinking {
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
continue
}
if name, ok := blockMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
messagesModified = true
}
}
}
}
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
......@@ -968,38 +787,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
modified = true
}
if !modified && !messagesModified {
return body, modelID, toolNameMap
if !modified {
return body, modelID
}
// 如果 messages 没有被修改,保留原始 messages 字节
if !messagesModified {
// 序列化非 messages 字段
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
return body, modelID
}
// 替换回原始的 messages
var newReq map[string]json.RawMessage
if err := json.Unmarshal(newBody, &newReq); err != nil {
return newBody, modelID, toolNameMap
}
if origMessages, ok := reqRaw["messages"]; ok {
newReq["messages"] = origMessages
}
finalBody, err := json.Marshal(newReq)
if err != nil {
return newBody, modelID, toolNameMap
}
return finalBody, modelID, toolNameMap
}
// messages 被修改了,需要完整序列化
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
return newBody, modelID, toolNameMap
return newBody, modelID
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
......@@ -1253,6 +1049,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
......@@ -1271,12 +1068,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredPlatform++
continue
}
if !account.IsSchedulableForModel(requestedModel) {
filteredModelScope++
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) {
filteredModelMapping++
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
filteredModelMapping++
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
filteredModelScope++
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 窗口费用检查(非粘性会话路径)
......@@ -1291,6 +1089,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
if len(modelScopeSkippedIDs) > 0 {
log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
derefGroupID(groupID), requestedModel, modelScopeSkippedIDs)
}
}
if len(routingCandidates) > 0 {
......@@ -1302,8 +1104,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
......@@ -1360,10 +1162,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad
for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID]
......@@ -1454,14 +1252,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
account.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
......@@ -1519,10 +1317,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
// 窗口费用检查(非粘性会话路径)
......@@ -1550,10 +1348,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
......@@ -1568,48 +1366,109 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
}
}
......@@ -2025,6 +1884,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
// filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minPriority := accounts[0].account.Priority
for _, acc := range accounts[1:] {
if acc.account.Priority < minPriority {
minPriority = acc.account.Priority
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.account.Priority == minPriority {
result = append(result, acc)
}
}
return result
}
// filterByMinLoadRate 过滤出负载率最低的账号集合
func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minLoadRate := accounts[0].loadInfo.LoadRate
for _, acc := range accounts[1:] {
if acc.loadInfo.LoadRate < minLoadRate {
minLoadRate = acc.loadInfo.LoadRate
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.loadInfo.LoadRate == minLoadRate {
result = append(result, acc)
}
}
return result
}
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 1. 找到最小的 LastUsedAt(nil 被视为最小)
var minTime *time.Time
hasNil := false
for _, acc := range accounts {
if acc.account.LastUsedAt == nil {
hasNil = true
break
}
if minTime == nil || acc.account.LastUsedAt.Before(*minTime) {
minTime = acc.account.LastUsedAt
}
}
// 2. 收集所有具有最小 LastUsedAt 的账号索引
var candidateIdxs []int
for i, acc := range accounts {
if hasNil {
if acc.account.LastUsedAt == nil {
candidateIdxs = append(candidateIdxs, i)
}
} else {
if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) {
candidateIdxs = append(candidateIdxs, i)
}
}
}
// 3. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 5. 随机选择一个
selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))]
return &accounts[selectedIdx]
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
......@@ -2047,6 +2006,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
......@@ -2128,11 +2168,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -2179,10 +2219,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2231,11 +2271,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -2271,10 +2311,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2341,11 +2381,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -2394,10 +2434,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2446,11 +2486,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -2488,10 +2528,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
......@@ -2535,11 +2575,38 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底
mapped := mapAntigravityModel(account, requestedModel)
if mapped == "" {
return false
}
// 应用 thinking 后缀后检查最终模型是否在账号映射中
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
finalModel := applyThinkingModelSuffix(mapped, enabled)
if finalModel == mapped {
return true // thinking 后缀未改变模型名,映射已通过
}
return account.IsModelSupported(finalModel)
}
return true
}
return s.isModelSupportedByAccount(account, requestedModel)
}
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
return mapAntigravityModel(account, requestedModel) != ""
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
......@@ -2553,13 +2620,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
......@@ -2964,7 +3024,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqModel := parsed.Model
reqStream := parsed.Stream
originalModel := reqModel
var toolNameMap map[string]string
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
......@@ -2988,7 +3047,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
}
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
......@@ -3375,7 +3434,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
......@@ -3388,7 +3447,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
......@@ -3849,6 +3908,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
)
}
// 非 failover 错误也支持错误透传规则匹配。
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
......@@ -3980,6 +4067,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
respBody,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed after retries",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
......@@ -4002,7 +4116,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -4035,7 +4149,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
......@@ -4054,7 +4169,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -4065,7 +4181,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......@@ -4098,33 +4214,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines := make([]string, 0, 4)
var toolInputBuffers map[int]string
if mimicClaudeCode {
toolInputBuffers = make(map[int]string)
}
transformToolInputJSON := func(raw string) string {
if !mimicClaudeCode {
return raw
}
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
var parsed any
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
return replaceToolNamesInText(raw, toolNameMap)
}
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
if changed {
if bytes, err := json.Marshal(rewritten); err == nil {
return string(bytes)
}
}
return raw
}
processSSEEvent := func(lines []string) ([]string, string, error) {
if len(lines) == 0 {
......@@ -4163,16 +4252,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var event map[string]any
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
// JSON 解析失败,直接透传原始数据
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
eventType, _ := event["type"].(string)
......@@ -4202,70 +4288,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
if mimicClaudeCode && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuffers[index] += partial
}
}
return nil, dataLine, nil
}
}
}
if mimicClaudeCode && eventType == "content_block_stop" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if buffered := toolInputBuffers[index]; buffered != "" {
delete(toolInputBuffers, index)
transformed := transformToolInputJSON(buffered)
synthetic := map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": transformed,
},
}
synthBytes, synthErr := json.Marshal(synthetic)
if synthErr == nil {
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
rewriteToolNamesInValue(event, toolNameMap)
stopBytes, stopErr := json.Marshal(event)
if stopErr == nil {
stopBlock := ""
if eventName != "" {
stopBlock = "event: " + eventName + "\n"
}
stopBlock += "data: " + string(stopBytes) + "\n\n"
return []string{synthBlock, stopBlock}, string(stopBytes), nil
}
}
}
}
}
if mimicClaudeCode {
rewriteToolNamesInValue(event, toolNameMap)
}
newData, err := json.Marshal(event)
if err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
// 序列化失败,直接透传原始数据
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
block := ""
......@@ -4364,126 +4395,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
switch v := value.(type) {
case map[string]any:
changed := false
rewritten := make(map[string]any, len(v))
for key, item := range v {
newKey := normalizeParamNameForOpenCode(key, cache)
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
if newKey != key {
changed = true
}
rewritten[newKey] = newItem
}
if !changed {
return value, false
}
return rewritten, true
case []any:
changed := false
rewritten := make([]any, len(v))
for idx, item := range v {
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
rewritten[idx] = newItem
}
if !changed {
return value, false
}
return rewritten, true
default:
return value, false
}
}
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
changed := false
if blockType, _ := v["type"].(string); blockType == "tool_use" {
if name, ok := v["name"].(string); ok {
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped != name {
v["name"] = mapped
changed = true
}
}
if input, ok := v["input"].(map[string]any); ok {
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
if inputChanged {
if m, ok := rewrittenInput.(map[string]any); ok {
v["input"] = m
changed = true
}
}
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
case []any:
changed := false
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
default:
return false
}
}
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
if text == "" {
return text
}
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
submatches := toolNameFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
name := submatches[1]
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped == name {
return match
}
return strings.Replace(match, name, mapped, 1)
})
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
submatches := modelFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
model := submatches[1]
mapped := claude.DenormalizeModelID(model)
if mapped == model {
return match
}
return strings.Replace(match, model, mapped, 1)
})
for mapped, original := range toolNameMap {
if mapped == "" || original == "" || mapped == original {
continue
}
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
}
return output
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens(标准Claude API格式)
var msgStart struct {
......@@ -4527,7 +4438,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -4559,9 +4470,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
......@@ -4579,46 +4487,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
// replaceModelInResponseBody 替换响应体中的model字段
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
if len(body) == 0 {
return body
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
replaced := replaceToolNamesInText(string(body), toolNameMap)
if replaced == string(body) {
return body
}
return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
}
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
......@@ -4630,6 +4508,7 @@ type RecordUsageInput struct {
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
}
......@@ -4646,6 +4525,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
......@@ -4828,6 +4716,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService *APIKeyService // API Key 配额服务(可选)
}
......@@ -4839,6 +4728,15 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
......@@ -5003,7 +4901,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestGatewayService_isModelSupportedByAccount_AntigravityModelMapping(t *testing.T) {
svc := &GatewayService{}
// 使用 model_mapping 作为白名单(通配符匹配)
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-sonnet-4-5",
"gemini-3-*": "gemini-3-flash",
},
},
}
// claude-* 通配符匹配
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-opus-4-6"))
// gemini-3-* 通配符匹配
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-pro-high"))
// gemini-2.5-* 不匹配(不在 model_mapping 中)
require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-flash"))
require.False(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro"))
// 其他平台模型不支持
require.False(t, svc.isModelSupportedByAccount(account, "gpt-4"))
// 空模型允许
require.True(t, svc.isModelSupportedByAccount(account, ""))
}
func TestGatewayService_isModelSupportedByAccount_AntigravityNoMapping(t *testing.T) {
svc := &GatewayService{}
// 未配置 model_mapping 时,使用默认映射(domain.DefaultAntigravityModelMapping)
// 只有默认映射中的模型才被支持
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{},
}
// 默认映射中的模型应该被支持
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-3-flash"))
require.True(t, svc.isModelSupportedByAccount(account, "gemini-2.5-pro"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-haiku-4-5"))
// 不在默认映射中的模型不被支持
require.False(t, svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022"))
require.False(t, svc.isModelSupportedByAccount(account, "claude-unknown-model"))
// 非 claude-/gemini- 前缀仍然不支持
require.False(t, svc.isModelSupportedByAccount(account, "gpt-4"))
}
// TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode 测试 thinking 模式下的模型支持检查
// 验证调度时使用映射后的最终模型名(包括 thinking 后缀)来检查 model_mapping 支持
func TestGatewayService_isModelSupportedByAccountWithContext_ThinkingMode(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
modelMapping map[string]any
requestedModel string
thinkingEnabled bool
expected bool
}{
// 场景 1: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=true
// mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false
{
name: "thinking_enabled_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: false,
},
// 场景 2: 只配置 claude-sonnet-4-5-thinking,请求 claude-sonnet-4-5 + thinking=false
// mapAntigravityModel 找不到 claude-sonnet-4-5 的映射 → 返回 false
{
name: "thinking_disabled_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: false,
},
// 场景 3: 配置 claude-sonnet-4-5(非 thinking),请求 claude-sonnet-4-5 + thinking=true
// 最终模型名 = claude-sonnet-4-5-thinking,不在 mapping 中,应该不匹配
{
name: "thinking_enabled_no_match_non_thinking_mapping",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: false,
},
// 场景 4: 配置两种模型,请求 claude-sonnet-4-5 + thinking=true,应该匹配 thinking 版本
{
name: "both_models_thinking_enabled_matches_thinking",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: true,
},
// 场景 5: 配置两种模型,请求 claude-sonnet-4-5 + thinking=false,应该匹配非 thinking 版本
{
name: "both_models_thinking_disabled_matches_non_thinking",
modelMapping: map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: true,
},
// 场景 6: 通配符 claude-* 应该同时匹配 thinking 和非 thinking
{
name: "wildcard_matches_thinking",
modelMapping: map[string]any{
"claude-*": "claude-sonnet-4-5",
},
requestedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: true, // claude-sonnet-4-5-thinking 匹配 claude-*
},
// 场景 7: 只配置 thinking 变体但没有基础模型映射 → 返回 false
// mapAntigravityModel 找不到 claude-opus-4-6 的映射
{
name: "opus_thinking_no_base_mapping_returns_false",
modelMapping: map[string]any{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
},
requestedModel: "claude-opus-4-6",
thinkingEnabled: true,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": tt.modelMapping,
},
}
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, tt.thinkingEnabled)
result := svc.isModelSupportedByAccountWithContext(ctx, account, tt.requestedModel)
require.Equal(t, tt.expected, result,
"isModelSupportedByAccountWithContext(ctx[thinking=%v], account, %q) = %v, want %v",
tt.thinkingEnabled, tt.requestedModel, result, tt.expected)
})
}
}
// TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault 测试自定义模型映射中
// 不在 DefaultAntigravityModelMapping 中的模型能通过调度
func TestGatewayService_isModelSupportedByAccount_CustomMappingNotInDefault(t *testing.T) {
svc := &GatewayService{}
// 自定义映射中包含不在默认映射中的模型
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "actual-upstream-model",
"gpt-4o": "some-upstream-model",
"llama-3-70b": "llama-3-70b-upstream",
"claude-sonnet-4-5": "claude-sonnet-4-5",
},
},
}
// 自定义模型应该通过(不在 DefaultAntigravityModelMapping 中也可以)
require.True(t, svc.isModelSupportedByAccount(account, "my-custom-model"))
require.True(t, svc.isModelSupportedByAccount(account, "gpt-4o"))
require.True(t, svc.isModelSupportedByAccount(account, "llama-3-70b"))
require.True(t, svc.isModelSupportedByAccount(account, "claude-sonnet-4-5"))
// 不在自定义映射中的模型不通过
require.False(t, svc.isModelSupportedByAccount(account, "gpt-3.5-turbo"))
require.False(t, svc.isModelSupportedByAccount(account, "unknown-model"))
// 空模型允许
require.True(t, svc.isModelSupportedByAccount(account, ""))
}
// TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking
// 测试自定义映射 + thinking 模式的交互
func TestGatewayService_isModelSupportedByAccountWithContext_CustomMappingThinking(t *testing.T) {
svc := &GatewayService{}
// 自定义映射同时配置基础模型和 thinking 变体
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"my-custom-model": "upstream-model",
},
},
}
// thinking=true: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → +thinking → check IsModelSupported(claude-sonnet-4-5-thinking)=true
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5"))
// thinking=false: claude-sonnet-4-5 → mapped=claude-sonnet-4-5 → check IsModelSupported(claude-sonnet-4-5)=true
ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "claude-sonnet-4-5"))
// 自定义模型(非 claude)不受 thinking 后缀影响,mapped 成功即通过
ctx = context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
require.True(t, svc.isModelSupportedByAccountWithContext(ctx, account, "my-custom-model"))
}
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
rateLimitService: &RateLimitService{},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
go func() {
defer func() { _ = pw.Close() }()
// Minimal SSE event to trigger parseSSEUsage
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n"))
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
}()
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens)
}
......@@ -200,7 +200,7 @@ func (s *GeminiMessagesCompatService) tryStickySessionHit(
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
if shouldClearStickySession(account, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
......@@ -230,7 +230,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
......@@ -362,7 +362,10 @@ func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
return mapAntigravityModel(account, requestedModel) != ""
}
return account.IsModelSupported(requestedModel)
}
......@@ -1498,6 +1501,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformGemini,
upstreamStatus,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus)
}
return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int
var errType, errMsg string
......@@ -2636,7 +2661,9 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil {
ts := time.Now().Unix() + int64(dur.Seconds())
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
// which can affect scheduling decisions around thresholds (like 10s).
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts
}
}
......
......@@ -268,6 +268,22 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
......@@ -883,7 +899,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
model: "claude-sonnet-4-5",
expected: true,
},
{
......@@ -892,6 +908,39 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
model: "gpt-4",
expected: false,
},
{
name: "Antigravity平台-空模型允许",
account: &Account{Platform: PlatformAntigravity},
model: "",
expected: true,
},
{
name: "Antigravity平台-自定义映射-支持自定义模型",
account: &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "upstream-model",
"gpt-4o": "some-model",
},
},
},
model: "my-custom-model",
expected: true,
},
{
name: "Antigravity平台-自定义映射-不在映射中的模型不支持",
account: &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"my-custom-model": "upstream-model",
},
},
},
model: "claude-sonnet-4-5",
expected: false,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini},
......
package service
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
h := xxhash.Sum64(data)
return strconv.FormatUint(h, 36)
}
// BuildGeminiDigestChain 根据 Gemini 请求生成摘要链
// 格式: s:<hash>-u:<hash>-m:<hash>-u:<hash>-...
// s = systemInstruction, u = user, m = model
func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
if req == nil {
return ""
}
var parts []string
// 1. system instruction
if req.SystemInstruction != nil && len(req.SystemInstruction.Parts) > 0 {
partsData, _ := json.Marshal(req.SystemInstruction.Parts)
parts = append(parts, "s:"+shortHash(partsData))
}
// 2. contents
for _, c := range req.Contents {
prefix := "u" // user
if c.Role == "model" {
prefix = "m"
}
partsData, _ := json.Marshal(c.Parts)
parts = append(parts, prefix+":"+shortHash(partsData))
}
return strings.Join(parts, "-")
}
// GenerateGeminiPrefixHash 生成前缀 hash(用于分区隔离)
// 组合: userID + apiKeyID + ip + userAgent + platform + model
// 返回 16 字符的 Base64 编码的 SHA256 前缀
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
// 组合所有标识符
combined := strconv.FormatInt(userID, 10) + ":" +
strconv.FormatInt(apiKeyID, 10) + ":" +
ip + ":" +
userAgent + ":" +
platform + ":" +
model
hash := sha256.Sum256([]byte(combined))
// 取前 12 字节,Base64 编码后正好 16 字符
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
if value == "" {
return "", 0, false
}
// 找到最后一个 ":" 的位置(因为 uuid 可能包含 ":")
i := strings.LastIndex(value, ":")
if i <= 0 || i >= len(value)-1 {
return "", 0, false
}
uuid = value[:i]
accountID, err := strconv.ParseInt(value[i+1:], 10, 64)
if err != nil {
return "", 0, false
}
return uuid, accountID, true
}
// FormatGeminiSessionValue 格式化 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func FormatGeminiSessionValue(uuid string, accountID int64) string {
return uuid + ":" + strconv.FormatInt(accountID, 10)
}
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
func GenerateGeminiDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return geminiDigestSessionKeyPrefix + prefix + ":" + uuidPart
}
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
accountID := int64(100)
// 模拟第一轮对话
req1 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}},
},
}
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
if foundUUID != sessionUUID {
t.Errorf("Round 2: expected UUID %s, got %s", sessionUUID, foundUUID)
}
if foundAccID != accountID {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Hello, what's your name?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I'm Claude, nice to meet you!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What can you do?"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "I can help with coding, writing, and more!"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Great, help me write some Go code"}}},
},
}
chain3 := BuildGeminiDigestChain(req3)
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
if foundUUID != sessionUUID {
t.Errorf("Round 3: expected UUID %s, got %s", sessionUUID, foundUUID)
}
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 第一个会话
req1 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Tell me about Go programming"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "What's the weather today?"}}},
},
}
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
// 查找应该返回最长匹配(账号 3)
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
func TestShortHash(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"empty", []byte{}},
{"simple", []byte("hello world")},
{"json", []byte(`{"role":"user","parts":[{"text":"hello"}]}`)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shortHash(tt.input)
// Base36 编码的 uint64 最长 13 个字符
if len(result) > 13 {
t.Errorf("shortHash result too long: %d characters", len(result))
}
// 相同输入应该产生相同输出
result2 := shortHash(tt.input)
if result != result2 {
t.Errorf("shortHash not deterministic: %s vs %s", result, result2)
}
})
}
}
func TestBuildGeminiDigestChain(t *testing.T) {
tests := []struct {
name string
req *antigravity.GeminiRequest
wantLen int // 预期的分段数量
hasEmpty bool // 是否应该是空字符串
}{
{
name: "nil request",
req: nil,
hasEmpty: true,
},
{
name: "empty contents",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{},
},
hasEmpty: true,
},
{
name: "single user message",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
},
wantLen: 1, // u:<hash>
},
{
name: "user and model messages",
req: &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi there"}}},
},
},
wantLen: 2, // u:<hash>-m:<hash>
},
{
name: "with system instruction",
req: &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Role: "user",
Parts: []antigravity.GeminiPart{{Text: "You are a helpful assistant"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
},
wantLen: 2, // s:<hash>-u:<hash>
},
{
name: "conversation with system",
req: &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Role: "user",
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "hi"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "how are you?"}}},
},
},
wantLen: 4, // s:<hash>-u:<hash>-m:<hash>-u:<hash>
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildGeminiDigestChain(tt.req)
if tt.hasEmpty {
if result != "" {
t.Errorf("expected empty string, got: %s", result)
}
return
}
// 检查分段数量
parts := splitChain(result)
if len(parts) != tt.wantLen {
t.Errorf("expected %d parts, got %d: %s", tt.wantLen, len(parts), result)
}
// 验证每个分段的格式
for _, part := range parts {
if len(part) < 3 || part[1] != ':' {
t.Errorf("invalid part format: %s", part)
}
prefix := part[0]
if prefix != 's' && prefix != 'u' && prefix != 'm' {
t.Errorf("invalid prefix: %c", prefix)
}
}
})
}
}
func TestGenerateGeminiPrefixHash(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
hash3 := GenerateGeminiPrefixHash(2, 100, "192.168.1.1", "Mozilla/5.0", "antigravity", "gemini-2.5-pro")
// 相同输入应该产生相同输出
if hash1 != hash2 {
t.Errorf("GenerateGeminiPrefixHash not deterministic: %s vs %s", hash1, hash2)
}
// 不同输入应该产生不同输出
if hash1 == hash3 {
t.Errorf("GenerateGeminiPrefixHash collision for different inputs")
}
// Base64 URL 编码的 12 字节正好是 16 字符
if len(hash1) != 16 {
t.Errorf("expected 16 characters, got %d: %s", len(hash1), hash1)
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
value string
wantUUID string
wantAccID int64
wantOK bool
}{
{
name: "empty",
value: "",
wantOK: false,
},
{
name: "no colon",
value: "abc123",
wantOK: false,
},
{
name: "valid",
value: "uuid-1234:100",
wantUUID: "uuid-1234",
wantAccID: 100,
wantOK: true,
},
{
name: "uuid with colon",
value: "a:b:c:123",
wantUUID: "a:b:c",
wantAccID: 123,
wantOK: true,
},
{
name: "invalid account id",
value: "uuid:abc",
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uuid, accID, ok := ParseGeminiSessionValue(tt.value)
if ok != tt.wantOK {
t.Errorf("ok: expected %v, got %v", tt.wantOK, ok)
}
if tt.wantOK {
if uuid != tt.wantUUID {
t.Errorf("uuid: expected %s, got %s", tt.wantUUID, uuid)
}
if accID != tt.wantAccID {
t.Errorf("accountID: expected %d, got %d", tt.wantAccID, accID)
}
}
})
}
}
func TestFormatGeminiSessionValue(t *testing.T) {
result := FormatGeminiSessionValue("test-uuid", 123)
expected := "test-uuid:123"
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
// 验证往返一致性
uuid, accID, ok := ParseGeminiSessionValue(result)
if !ok {
t.Error("ParseGeminiSessionValue failed on formatted value")
}
if uuid != "test-uuid" || accID != 123 {
t.Errorf("round-trip failed: uuid=%s, accID=%d", uuid, accID)
}
}
// splitChain 辅助函数:按 "-" 分割摘要链
func splitChain(chain string) []string {
if chain == "" {
return nil
}
var parts []string
start := 0
for i := 0; i < len(chain); i++ {
if chain[i] == '-' {
parts = append(parts, chain[start:i])
start = i + 1
}
}
if start < len(chain) {
parts = append(parts, chain[start:])
}
return parts
}
func TestDigestChainDifferentSysInstruction(t *testing.T) {
req1 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_ORIGINAL"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
req2 := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "SYS_MODIFIED"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
if chain1 == chain2 {
t.Error("Different systemInstruction should produce different chains")
}
}
func TestDigestChainTamperedMiddleContent(t *testing.T) {
req1 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "ORIGINAL_REPLY"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
},
}
req2 := &antigravity.GeminiRequest{
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "hello"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "TAMPERED_REPLY"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "next"}}},
},
}
chain1 := BuildGeminiDigestChain(req1)
chain2 := BuildGeminiDigestChain(req2)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
if chain1 == chain2 {
t.Error("Tampered middle content should produce different chains")
}
// 验证第一个 user 的 hash 相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
if parts1[1] == parts2[1] {
t.Error("Model reply hash should be different")
}
}
func TestGenerateGeminiDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "gemini:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars prefix and uuid",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "gemini:digest:12345678:abcdefgh",
},
{
name: "short hash and short uuid (less than 8)",
prefixHash: "abc",
uuid: "xyz",
want: "gemini:digest:abc:xyz",
},
{
name: "empty hash and uuid",
prefixHash: "",
uuid: "",
want: "gemini:digest::",
},
{
name: "normal prefix with short uuid",
prefixHash: "abcdefgh12345678",
uuid: "short",
want: "gemini:digest:abcdefgh:short",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateGeminiDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateGeminiDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证确定性:相同输入产生相同输出
t.Run("deterministic", func(t *testing.T) {
hash := "testprefix123456"
uuid := "test-uuid-12345"
result1 := GenerateGeminiDigestSessionKey(hash, uuid)
result2 := GenerateGeminiDigestSessionKey(hash, uuid)
if result1 != result2 {
t.Errorf("GenerateGeminiDigestSessionKey not deterministic: %s vs %s", result1, result2)
}
})
// 验证不同 uuid 产生不同 sessionKey(负载均衡核心逻辑)
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
uuid1 := "uuid0001-session-a"
uuid2 := "uuid0002-session-b"
result1 := GenerateGeminiDigestSessionKey(hash, uuid1)
result2 := GenerateGeminiDigestSessionKey(hash, uuid2)
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
package service
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const modelRateLimitsKey = "model_rate_limits"
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
model := strings.ToLower(strings.TrimSpace(requestedModel))
if model == "" {
return "", false
// isRateLimitActiveForKey 检查指定 key 的限流是否生效
func (a *Account) isRateLimitActiveForKey(key string) bool {
resetAt := a.modelRateLimitResetAt(key)
return resetAt != nil && time.Now().Before(*resetAt)
}
// getRateLimitRemainingForKey 获取指定 key 的限流剩余时间,0 表示未限流或已过期
func (a *Account) getRateLimitRemainingForKey(key string) time.Duration {
resetAt := a.modelRateLimitResetAt(key)
if resetAt == nil {
return 0
}
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "sonnet") {
return modelRateLimitScopeClaudeSonnet, true
remaining := time.Until(*resetAt)
if remaining > 0 {
return remaining
}
return "", false
return 0
}
func (a *Account) isModelRateLimited(requestedModel string) bool {
scope, ok := resolveModelRateLimitScope(requestedModel)
if !ok {
func (a *Account) isModelRateLimitedWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
resetAt := a.modelRateLimitResetAt(scope)
if resetAt == nil {
modelKey := a.GetMappedModel(requestedModel)
if a.Platform == PlatformAntigravity {
modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
}
modelKey = strings.TrimSpace(modelKey)
if modelKey == "" {
return false
}
return time.Now().Before(*resetAt)
return a.isRateLimitActiveForKey(modelKey)
}
// GetModelRateLimitRemainingTime 获取模型限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetModelRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetModelRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
func (a *Account) GetModelRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelKey := a.GetMappedModel(requestedModel)
if a.Platform == PlatformAntigravity {
modelKey = resolveFinalAntigravityModelKey(ctx, a, requestedModel)
}
modelKey = strings.TrimSpace(modelKey)
if modelKey == "" {
return 0
}
return a.getRateLimitRemainingForKey(modelKey)
}
func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requestedModel string) string {
modelKey := mapAntigravityModel(account, requestedModel)
if modelKey == "" {
return ""
}
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking)
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
modelKey = applyThinkingModelSuffix(modelKey, enabled)
}
return modelKey
}
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
......
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
func TestIsModelRateLimited(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
expected bool
}{
{
name: "official model ID hit - claude-sonnet-4-5",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: true,
},
{
name: "official model ID hit via mapping - request claude-3-5-sonnet, mapped to claude-sonnet-4-5",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "claude-sonnet-4-5",
},
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-3-5-sonnet",
expected: true,
},
{
name: "no rate limit - expired",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: false,
},
{
name: "no rate limit - no matching key",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-flash": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
expected: false,
},
{
name: "no rate limit - unsupported model",
account: &Account{},
requestedModel: "gpt-4",
expected: false,
},
{
name: "no rate limit - empty model",
account: &Account{},
requestedModel: "",
expected: false,
},
{
name: "gemini model hit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-high",
expected: true,
},
{
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-preview",
expected: true,
},
{
name: "non-antigravity platform - gemini-3-pro-preview NOT mapped",
account: &Account{
Platform: PlatformGemini,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "gemini-3-pro-preview",
expected: false, // gemini 平台不走 antigravity 映射
},
{
name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-opus-4-5-thinking",
expected: true,
},
{
name: "no scope fallback - claude_sonnet should not match",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": future,
},
},
},
},
requestedModel: "claude-3-5-sonnet-20241022",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.isModelRateLimitedWithContext(context.Background(), tt.requestedModel)
if result != tt.expected {
t.Errorf("isModelRateLimited(%q) = %v, want %v", tt.requestedModel, result, tt.expected)
}
})
}
}
func TestIsModelRateLimited_Antigravity_ThinkingAffectsModelKey(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute).Format(time.RFC3339)
account := &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5-thinking": map[string]any{
"rate_limit_reset_at": future,
},
},
},
}
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
if !account.isModelRateLimitedWithContext(ctx, "claude-sonnet-4-5") {
t.Errorf("expected model to be rate limited")
}
}
func TestGetModelRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
future5m := now.Add(5 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "model rate limited - direct hit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "model rate limited - via mapping",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "claude-sonnet-4-5",
},
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-3-5-sonnet",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "expired rate limit",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "no rate limit data",
account: &Account{},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "no scope fallback",
account: &Account{
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-3-5-sonnet-20241022",
minExpected: 0,
maxExpected: 0,
},
{
name: "antigravity platform - claude-opus-4-5-thinking mapped to opus-4-6-thinking",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-opus-4-5-thinking",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetModelRateLimitRemainingTimeWithContext(context.Background(), tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetModelRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
future5m := now.Add(5 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "model remaining > scope remaining - returns model",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "scope remaining > model remaining - returns scope",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "only model rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "only scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "neither rate limited",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetRateLimitRemainingTimeWithContext(context.Background(), tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
......@@ -2,19 +2,7 @@ package service
import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
//go:embed prompts/codex_cli_instructions.md
......@@ -77,12 +65,6 @@ type codexTransformResult struct {
PromptCacheKey string
}
type opencodeCacheMetadata struct {
ETag string `json:"etag"`
LastFetch string `json:"lastFetch,omitempty"`
LastChecked int64 `json:"lastChecked"`
}
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
......@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
return ""
}
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
cacheDir := codexCachePath("")
if cacheDir == "" {
return ""
}
cacheFile := filepath.Join(cacheDir, cacheFileName)
metaFile := filepath.Join(cacheDir, metaFileName)
var cachedContent string
if content, ok := readFile(cacheFile); ok {
cachedContent = content
}
var meta opencodeCacheMetadata
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
return cachedContent
}
}
content, etag, status, err := fetchWithETag(url, meta.ETag)
if err == nil && status == http.StatusNotModified && cachedContent != "" {
return cachedContent
}
if err == nil && status >= 200 && status < 300 && content != "" {
_ = writeFile(cacheFile, content)
meta = opencodeCacheMetadata{
ETag: etag,
LastFetch: time.Now().UTC().Format(time.RFC3339),
LastChecked: time.Now().UnixMilli(),
}
_ = writeJSON(metaFile, meta)
return content
}
return cachedContent
}
func getOpenCodeCodexHeader() string {
// 优先从 opencode 仓库缓存获取指令。
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
// 若 opencode 指令可用,直接返回。
if opencodeInstructions != "" {
return opencodeInstructions
}
// 否则回退使用本地 Codex CLI 指令。
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
return getCodexCLIInstructions()
}
......@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI {
return applyCodexCLIInstructions(reqBody)
......@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions,不修改
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
instructions := strings.TrimSpace(getCodexCLIInstructions())
if instructions != "" {
reqBody["instructions"] = instructions
return true
......@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
return false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
// 优先使用内置 Codex CLI 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
......@@ -346,47 +283,6 @@ func isInstructionsEmpty(reqBody map[string]any) bool {
return strings.TrimSpace(str) == ""
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions == "" {
return false
}
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) != codexInstructions {
reqBody["instructions"] = codexInstructions
return true
}
return false
}
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
func IsInstructionError(errorMessage string) bool {
if errorMessage == "" {
return false
}
lowerMsg := strings.ToLower(errorMessage)
instructionKeywords := []string{
"instruction",
"instructions",
"system prompt",
"system message",
"invalid prompt",
"prompt format",
}
for _, keyword := range instructionKeywords {
if strings.Contains(lowerMsg, keyword) {
return true
}
}
return false
}
// filterCodexInput 按需过滤 item_reference 与 id。
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
func filterCodexInput(input []any, preserveReferences bool) []any {
......@@ -530,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
return modified
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.2",
......@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
// 显式 store=true 也会强制为 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
}
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"tools": []any{
......@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -189,30 +178,8 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
}
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时不修改
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -230,7 +197,6 @@ func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *test
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
......@@ -246,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
}
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令覆盖
setupCodexCache(t)
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
reqBody := map[string]any{
"model": "gpt-5.1",
......
......@@ -24,6 +24,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
......@@ -332,7 +334,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
if shouldClearStickySession(account, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
......@@ -498,7 +500,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
}
......@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
......@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{
......@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("user-agent", customUA)
}
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
}
// Ensure required headers exist
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
......@@ -1087,6 +1099,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformOpenAI,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......@@ -1209,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
......@@ -1228,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -1239,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......@@ -1418,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// Replace model in response
if m, ok := event["model"].(string); ok && m == fromModel {
event["model"] = toModel
newData, err := json.Marshal(event)
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "model", toModel)
if err != nil {
return line
}
return "data: " + string(newData)
return "data: " + newData
}
// Check nested response
if response, ok := event["response"].(map[string]any); ok {
if m, ok := response["model"].(string); ok && m == fromModel {
response["model"] = toModel
newData, err := json.Marshal(event)
// 检查嵌套的 response.model 字段
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "response.model", toModel)
if err != nil {
return line
}
return "data: " + string(newData)
}
return "data: " + newData
}
return line
......@@ -1662,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
return body
}
// OpenAIRecordUsageInput input for recording usage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment