"frontend/vscode:/vscode.git/clone" did not exist on "aa1a3b9a74a505eb4362a346dfc9026dae8629d3"
Commit fd43be8d authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突



- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parents 792bef61 836ba14b
...@@ -5,6 +5,7 @@ package repository ...@@ -5,6 +5,7 @@ package repository
import ( import (
"math" "math"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { ...@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
}) })
} }
} }
func TestJitteredTTL(t *testing.T) {
const (
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
)
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
}
}
func TestJitteredTTL_HasVariation(t *testing.T) {
// 多次调用应该产生不同的值(验证抖动存在)
seen := make(map[time.Duration]struct{}, 50)
for i := 0; i < 50; i++ {
seen[jitteredTTL()] = struct{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
}
...@@ -194,6 +194,53 @@ var ( ...@@ -194,6 +194,53 @@ var (
return result return result
`) `)
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
getUsersLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local userID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:user:' .. userID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'concurrency:wait:' .. userID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, userID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots // cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID} // KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds) // ARGV[1] = TTL (seconds)
...@@ -384,6 +431,43 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] ...@@ -384,6 +431,43 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
return loadMap, nil return loadMap, nil
} }
func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
if len(users) == 0 {
return map[int64]*service.UserLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, u := range users {
args = append(args, u.ID, u.MaxConcurrency)
}
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.UserLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[userID] = &service.UserLoadInfo{
UserID: userID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID) key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
......
...@@ -11,6 +11,63 @@ import ( ...@@ -11,6 +11,63 @@ import (
const stickySessionPrefix = "sticky_session:" const stickySessionPrefix = "sticky_session:"
// Gemini Trie Lua 脚本
const (
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
geminiTrieFindScript = `
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
end
return lastMatch
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain
// ARGV[2] = value (uuid:accountID)
// ARGV[3] = TTL seconds
geminiTrieSaveScript = `
local chain = ARGV[1]
local value = ARGV[2]
local ttl = tonumber(ARGV[3])
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
end
redis.call('HSET', KEYS[1], path, value)
redis.call('EXPIRE', KEYS[1], ttl)
return "OK"
`
)
// 模型负载统计相关常量
const (
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
)
type gatewayCache struct { type gatewayCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -51,3 +108,133 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 ...@@ -51,3 +108,133 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash) key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
// ============ Antigravity 模型负载统计方法 ============
// modelLoadKey 构建模型调用次数 key
// 格式: ag:model_load:{accountID}:{model}
func modelLoadKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
}
// modelLastUsedKey 构建模型最后调度时间 key
// 格式: ag:model_last_used:{accountID}:{model}
func modelLastUsedKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
}
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
// 返回更新后的调用次数
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
loadKey := modelLoadKey(accountID, model)
lastUsedKey := modelLastUsedKey(accountID, model)
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, loadKey)
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, err
}
return incrCmd.Val(), nil
}
// GetModelLoadBatch 批量获取账号的模型负载信息
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
if len(accountIDs) == 0 {
return make(map[int64]*service.ModelLoadInfo), nil
}
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
}
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
func (c *gatewayCache) pipelineModelLoadGet(
ctx context.Context,
accountIDs []int64,
model string,
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
pipe := c.rdb.Pipeline()
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
}
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
return loadCmds, lastUsedCmds
}
// parseModelLoadResults 解析 Pipeline 结果
func (c *gatewayCache) parseModelLoadResults(
accountIDs []int64,
loadCmds map[int64]*redis.StringCmd,
lastUsedCmds map[int64]*redis.StringCmd,
) map[int64]*service.ModelLoadInfo {
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
for _, id := range accountIDs {
result[id] = &service.ModelLoadInfo{
CallCount: getInt64OrZero(loadCmds[id]),
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
}
}
return result
}
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
func getInt64OrZero(cmd *redis.StringCmd) int64 {
val, _ := cmd.Int64()
return val
}
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
val, err := cmd.Int64()
if err != nil {
return time.Time{}
}
return time.Unix(val, 0)
}
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}
...@@ -104,6 +104,158 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { ...@@ -104,6 +104,158 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
} }
// ============ Gemini Trie 会话测试 ============
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
groupID := int64(1)
prefixHash := "testprefix"
digestChain := "u:hash1-m:hash2-u:hash3"
uuid := "test-uuid-123"
accountID := int64(42)
// 保存会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
require.NoError(s.T(), err, "SaveGeminiSession")
// 精确匹配查找
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
require.True(s.T(), found, "should find exact match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
groupID := int64(1)
prefixHash := "prefixmatch"
shortChain := "u:a-m:b"
longChain := "u:a-m:b-u:c-m:d"
uuid := "uuid-prefix"
accountID := int64(100)
// 保存短链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
require.NoError(s.T(), err)
// 用长链查找,应该匹配到短链(前缀匹配)
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
require.True(s.T(), found, "should find prefix match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
groupID := int64(1)
prefixHash := "longestmatch"
// 保存多个不同长度的链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
require.NoError(s.T(), err)
// 查找更长的链,应该匹配到最长的前缀
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
require.True(s.T(), found, "should find longest prefix match")
require.Equal(s.T(), "uuid-long", foundUUID)
require.Equal(s.T(), int64(3), foundAccountID)
// 查找中等长度的链
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-medium", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
groupID := int64(1)
prefixHash := "nomatch"
digestChain := "u:a-m:b"
// 保存一个会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
require.NoError(s.T(), err)
// 用不同的链查找,应该找不到
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
require.False(s.T(), found, "should not find non-matching chain")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
groupID := int64(1)
digestChain := "u:a-m:b"
// 保存到 prefixHash1
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
require.False(s.T(), found, "different prefixHash should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
prefixHash := "sameprefix"
digestChain := "u:a-m:b"
// 保存到 groupID 1
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 groupID 2 查找,应该找不到(分组隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
require.False(s.T(), found, "different groupID should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
groupID := int64(1)
prefixHash := "emptytest"
// 空链不应该保存
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
require.NoError(s.T(), err, "empty chain should not error")
// 空链查找应该返回 false
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
require.False(s.T(), found, "empty chain should not match")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
groupID := int64(1)
prefixHash := "multisession"
// 保存多个不同会话(模拟 1000 个并发会话的场景)
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
require.NoError(s.T(), err)
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
require.True(s.T(), found, "should find session: %s", sess.chain)
require.Equal(s.T(), sess.uuid, foundUUID)
require.Equal(s.T(), sess.accountID, foundAccountID)
}
// 验证继续对话的场景
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-2", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func TestGatewayCacheSuite(t *testing.T) { func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite)) suite.Run(t, new(GatewayCacheSuite))
} }
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ============ Gateway Cache 模型负载统计集成测试 ============
type GatewayCacheModelLoadSuite struct {
suite.Suite
}
func TestGatewayCacheModelLoadSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheModelLoadSuite))
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(123)
model := "claude-sonnet-4-20250514"
// 首次调用应返回 1
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
// 第二次调用应返回 2
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(2), count2)
// 第三次调用应返回 3
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(3), count3)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(456)
model1 := "claude-sonnet-4-20250514"
model2 := "claude-opus-4-5-20251101"
// 不同模型应该独立计数
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(2), count1Again)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
account1 := int64(111)
account2 := int64(222)
model := "gemini-2.5-pro"
// 不同账号应该独立计数
count1, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
require.NoError(t, err)
require.NotNil(t, result)
require.Empty(t, result)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
// 查询不存在的账号应返回零值
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
require.NoError(t, err)
require.Len(t, result, 2)
require.Equal(t, int64(0), result[9999].CallCount)
require.True(t, result[9999].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[9998].CallCount)
require.True(t, result[9998].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(789)
model := "claude-sonnet-4-20250514"
// 先增加调用次数
beforeIncr := time.Now()
_, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
afterIncr := time.Now()
// 获取负载信息
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
require.NoError(t, err)
require.Len(t, result, 1)
loadInfo := result[accountID]
require.NotNil(t, loadInfo)
require.Equal(t, int64(3), loadInfo.CallCount)
require.False(t, loadInfo.LastUsedAt.IsZero())
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
model := "claude-opus-4-5-20251101"
account1 := int64(1001)
account2 := int64(1002)
account3 := int64(1003) // 不调用
// account1 调用 2 次
_, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
// account2 调用 5 次
for i := 0; i < 5; i++ {
_, err = cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
}
// 批量获取
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, int64(2), result[account1].CallCount)
require.False(t, result[account1].LastUsedAt.IsZero())
require.Equal(t, int64(5), result[account2].CallCount)
require.False(t, result[account2].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[account3].CallCount)
require.True(t, result[account3].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(2001)
model1 := "claude-sonnet-4-20250514"
model2 := "gemini-2.5-pro"
// 对 model1 调用 3 次
for i := 0; i < 3; i++ {
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
}
// 获取 model1 的负载
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
require.NoError(t, err)
require.Equal(t, int64(3), result1[accountID].CallCount)
// 获取 model2 的负载(应该为 0)
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
require.NoError(t, err)
require.Equal(t, int64(0), result2[accountID].CallCount)
}
// ============ 辅助函数测试 ============
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
t := s.T()
key := modelLoadKey(123, "claude-sonnet-4")
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
}
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
t := s.T()
key := modelLastUsedKey(456, "gemini-2.5-pro")
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
}
...@@ -98,12 +98,16 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -98,12 +98,16 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
if err != nil { if err != nil {
return err return err
} }
defer func() { _ = out.Close() }()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1) limited := io.LimitReader(resp.Body, maxSize+1)
written, err := io.Copy(out, limited) written, err := io.Copy(out, limited)
// Close file before attempting to remove (required on Windows)
_ = out.Close()
if err != nil { if err != nil {
_ = os.Remove(dest) // Clean up partial file (best-effort)
return err return err
} }
......
...@@ -191,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -191,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
q = q.Where(group.IsExclusiveEQ(*isExclusive)) q = q.Where(group.IsExclusiveEQ(*isExclusive))
} }
total, err := q.Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
......
...@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina ...@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q = q.Where(promocode.CodeContainsFold(search)) q = q.Where(promocode.CodeContainsFold(search))
} }
total, err := q.Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo ...@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q := r.client.PromoCodeUsage.Query(). q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
......
...@@ -60,6 +60,25 @@ func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy ...@@ -60,6 +60,25 @@ func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy
return proxyEntityToService(m), nil return proxyEntityToService(m), nil
} }
func (r *proxyRepository) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
proxies, err := r.client.Proxy.Query().
Where(proxy.IDIn(ids...)).
All(ctx)
if err != nil {
return nil, err
}
out := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *proxyEntityToService(proxies[i]))
}
return out, nil
}
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
builder := r.client.Proxy.UpdateOneID(proxyIn.ID). builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
SetName(proxyIn.Name). SetName(proxyIn.Name).
......
...@@ -24,6 +24,22 @@ import ( ...@@ -24,6 +24,22 @@ import (
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
"day": "YYYY-MM-DD",
"week": "IYYY-IW",
"month": "YYYY-MM",
}
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
func safeDateFormat(granularity string) string {
if f, ok := dateFormatWhitelist[granularity]; ok {
return f
}
return "YYYY-MM-DD"
}
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor sql sqlExecutor
...@@ -567,7 +583,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -567,7 +583,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
} }
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
...@@ -813,19 +829,19 @@ func resolveUsageStatsTimezone() string { ...@@ -813,19 +829,19 @@ func resolveUsageStatsTimezone() string {
} }
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
...@@ -911,10 +927,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint ...@@ -911,10 +927,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := safeDateFormat(granularity)
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
WITH top_keys AS ( WITH top_keys AS (
...@@ -969,10 +982,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, ...@@ -969,10 +982,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date // GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := safeDateFormat(granularity)
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
WITH top_users AS ( WITH top_users AS (
...@@ -1231,10 +1241,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey ...@@ -1231,10 +1241,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := safeDateFormat(granularity)
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
...@@ -1372,13 +1379,22 @@ type UsageStats = usagestats.UsageStats ...@@ -1372,13 +1379,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user // BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { // If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats) result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 { if len(userIDs) == 0 {
return result, nil return result, nil
} }
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range userIDs { for _, id := range userIDs {
result[id] = &BatchUserUsageStats{UserID: id} result[id] = &BatchUserUsageStats{UserID: id}
} }
...@@ -1386,10 +1402,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -1386,10 +1402,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query := ` query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs FROM usage_logs
WHERE user_id = ANY($1) WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY user_id GROUP BY user_id
` `
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1446,13 +1462,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -1446,13 +1462,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key // BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { // If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats) result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return result, nil return result, nil
} }
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range apiKeyIDs { for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
} }
...@@ -1460,10 +1485,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe ...@@ -1460,10 +1485,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query := ` query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs FROM usage_logs
WHERE api_key_id = ANY($1) WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id GROUP BY api_key_id
` `
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1519,10 +1544,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe ...@@ -1519,10 +1544,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters // GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := safeDateFormat(granularity)
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
......
...@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { ...@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID]) s.Require().NotNil(stats[user1.ID])
...@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { ...@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
} }
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Empty(stats) s.Require().Empty(stats)
} }
...@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { ...@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Empty(stats) s.Require().Empty(stats)
} }
......
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSafeDateFormat(t *testing.T) {
tests := []struct {
name string
granularity string
expected string
}{
// 合法值
{"hour", "hour", "YYYY-MM-DD HH24:00"},
{"day", "day", "YYYY-MM-DD"},
{"week", "week", "IYYY-IW"},
{"month", "month", "YYYY-MM"},
// 非法值回退到默认
{"空字符串", "", "YYYY-MM-DD"},
{"未知粒度 year", "year", "YYYY-MM-DD"},
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
// 恶意字符串
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
{"带引号", "day'", "YYYY-MM-DD"},
{"带括号", "day)", "YYYY-MM-DD"},
{"Unicode", "日", "YYYY-MM-DD"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := safeDateFormat(tc.granularity)
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
})
}
}
...@@ -597,13 +597,13 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -597,13 +597,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo, nil) userService := service.NewUserService(userRepo, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
...@@ -1068,6 +1068,10 @@ func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, err ...@@ -1068,6 +1068,10 @@ func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, err
return nil, service.ErrProxyNotFound return nil, service.ErrProxyNotFound
} }
func (stubProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
return nil, errors.New("not implemented")
}
func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error { func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -1607,11 +1611,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i ...@@ -1607,11 +1611,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -176,6 +176,12 @@ func validateJWTForAdmin( ...@@ -176,6 +176,12 @@ func validateJWTForAdmin(
return false return false
} }
// 校验 TokenVersion,确保管理员改密后旧 token 失效
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return false
}
// 检查管理员权限 // 检查管理员权限
if !user.IsAdmin() { if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
......
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
Email: "admin@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
TokenVersion: 2,
Concurrency: 1,
}
userRepo := &stubUserRepo{
getByID: func(ctx context.Context, id int64) (*service.User, error) {
if id != admin.ID {
return nil, service.ErrUserNotFound
}
clone := *admin
return &clone, nil
},
}
userService := service.NewUserService(userRepo, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
}
type stubUserRepo struct {
getByID func(ctx context.Context, id int64) (*service.User, error)
}
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
panic("unexpected Create call")
}
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
if s.getByID == nil {
panic("GetByID not stubbed")
}
return s.getByID(ctx, id)
}
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
panic("unexpected GetByEmail call")
}
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
panic("unexpected Update call")
}
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
...@@ -3,7 +3,6 @@ package middleware ...@@ -3,7 +3,6 @@ package middleware
import ( import (
"context" "context"
"errors" "errors"
"log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil { if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅 // 订阅模式:获取订阅(L1 缓存 + singleflight)
subscription, err := subscriptionService.GetActiveSubscription( subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(), c.Request.Context(),
apiKey.User.ID, apiKey.User.ID,
...@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 验证订阅状态(是否过期、暂停等) // 合并验证 + 限额检查(纯内存操作)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) if err != nil {
return code := "SUBSCRIPTION_INVALID"
} status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
// 激活滑动窗口(首次使用时) errors.Is(err, service.ErrWeeklyLimitExceeded) ||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { errors.Is(err, service.ErrMonthlyLimitExceeded) {
log.Printf("Failed to activate subscription windows: %v", err) code = "USAGE_LIMIT_EXCEEDED"
} status = 429
}
// 检查并重置过期窗口 AbortWithError(c, status, code, err.Error())
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return return
} }
// 将订阅信息存入上下文 // 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription) c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else { } else {
// 余额模式:检查用户余额 // 余额模式:检查用户余额
if apiKey.User.Balance <= 0 { if apiKey.User.Balance <= 0 {
......
...@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
} }
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder() w := httptest.NewRecorder()
......
...@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { ...@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
// 处理预检请求 // 处理预检请求
if c.Request.Method == http.MethodOptions { if c.Request.Method == http.MethodOptions {
......
...@@ -78,6 +78,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -78,6 +78,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
// Realtime ops signals // Realtime ops signals
ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats) ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats)
ops.GET("/user-concurrency", h.Admin.Ops.GetUserConcurrencyStats)
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability) ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary) ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary)
...@@ -222,10 +223,15 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -222,10 +223,15 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.GET("/data", h.Admin.Account.ExportData)
accounts.POST("/data", h.Admin.Account.ImportData)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
// Claude OAuth routes // Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL) accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
...@@ -281,6 +287,8 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -281,6 +287,8 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
proxies.GET("", h.Admin.Proxy.List) proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll) proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/data", h.Admin.Proxy.ExportData)
proxies.POST("/data", h.Admin.Proxy.ImportData)
proxies.GET("/:id", h.Admin.Proxy.GetByID) proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create) proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.PUT("/:id", h.Admin.Proxy.Update)
......
...@@ -3,9 +3,12 @@ package service ...@@ -3,9 +3,12 @@ package service
import ( import (
"encoding/json" "encoding/json"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
type Account struct { type Account struct {
...@@ -347,10 +350,18 @@ func parseTempUnschedInt(value any) int { ...@@ -347,10 +350,18 @@ func parseTempUnschedInt(value any) int {
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil return nil
} }
raw, ok := a.Credentials["model_mapping"] raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil { if !ok || raw == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil return nil
} }
if m, ok := raw.(map[string]any); ok { if m, ok := raw.(map[string]any); ok {
...@@ -364,27 +375,46 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -364,27 +375,46 @@ func (a *Account) GetModelMapping() map[string]string {
return result return result
} }
} }
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil return nil
} }
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 如果未配置 mapping,返回 true(允许所有模型)
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return true // 无映射 = 允许所有
}
// 精确匹配
if _, exists := mapping[requestedModel]; exists {
return true return true
} }
_, exists := mapping[requestedModel] // 通配符匹配
return exists for pattern := range mapping {
if matchWildcard(pattern, requestedModel) {
return true
}
}
return false
} }
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel return requestedModel
} }
// 精确匹配优先
if mappedModel, exists := mapping[requestedModel]; exists { if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel return mappedModel
} }
return requestedModel // 通配符匹配(最长优先)
return matchWildcardMapping(mapping, requestedModel)
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
...@@ -426,6 +456,53 @@ func (a *Account) GetClaudeUserID() string { ...@@ -426,6 +456,53 @@ func (a *Account) GetClaudeUserID() string {
return "" return ""
} }
// matchAntigravityWildcard 通配符匹配(仅支持末尾 *)
// 用于 model_mapping 的通配符匹配
func matchAntigravityWildcard(pattern, str string) bool {
if strings.HasSuffix(pattern, "*") {
prefix := pattern[:len(pattern)-1]
return strings.HasPrefix(str, prefix)
}
return pattern == str
}
// matchWildcard 通用通配符匹配(仅支持末尾 *)
// 复用 Antigravity 的通配符逻辑,供其他平台使用
func matchWildcard(pattern, str string) bool {
return matchAntigravityWildcard(pattern, str)
}
// matchWildcardMapping 通配符映射匹配(最长优先)
// 如果没有匹配,返回原始字符串
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
type patternMatch struct {
pattern string
target string
}
var matches []patternMatch
for pattern, target := range mapping {
if matchWildcard(pattern, requestedModel) {
matches = append(matches, patternMatch{pattern, target})
}
}
if len(matches) == 0 {
return requestedModel // 无匹配,返回原始模型名
}
// 按 pattern 长度降序排序
sort.Slice(matches, func(i, j int) bool {
if len(matches[i].pattern) != len(matches[j].pattern) {
return len(matches[i].pattern) > len(matches[j].pattern)
}
return matches[i].pattern < matches[j].pattern
})
return matches[0].target
}
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment