Commit a14dfb76 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'dev-release'

parents f3605ddc 2588fa6a
package openai
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
......@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
......
......@@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
updated, err := r.client.APIKey.UpdateOneID(id).
Where(apikey.DeletedAtIsNil()).
AddQuotaUsed(amount).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
newValue := m.QuotaUsed + amount
// Update with new value
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
Save(ctx)
if err != nil {
return 0, err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
}
return newValue, nil
return updated.QuotaUsed, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
......
......@@ -4,11 +4,14 @@ package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
......@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
// --- IncrementQuotaUsed ---
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
user := s.mustCreateUser("incr-basic@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
s.Require().NoError(err, "IncrementQuotaUsed")
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsed second")
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
user := s.mustCreateUser("incr-deleted@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t)
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
ctx := context.Background()
// 创建测试用户和 API Key
u, err := client.User.Create().
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
SetPasswordHash("hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(ctx)
require.NoError(t, err, "create user")
k := &service.APIKey{
UserID: u.ID,
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
Name: "Concurrent",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, k), "create api key")
t.Cleanup(func() {
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const goroutines = 10
const increment = 1.0
var wg sync.WaitGroup
errs := make([]error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
}(i)
}
wg.Wait()
for i, e := range errs {
require.NoError(t, e, "goroutine %d failed", i)
}
// 验证最终结果
got, err := repo.GetByID(ctx, k.ID)
require.NoError(t, err, "GetByID")
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
}
......@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand"
"strconv"
"time"
......@@ -16,8 +17,15 @@ const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
)
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
func jitteredTTL() time.Duration {
jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter
return billingCacheTTL + jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
......@@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
......@@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
pipe.Expire(ctx, key, jitteredTTL())
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
......
......@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
fn func(ctx context.Context, cache service.BillingCache)
expectErr bool
}{
{
name: "key_not_exists_returns_nil",
fn: func(ctx context.Context, cache service.BillingCache) {
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
err := cache.DeductUserBalance(ctx, 99999, 1.0)
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
},
},
{
name: "existing_key_deducts_successfully",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
err := cache.DeductUserBalance(ctx, 200, 10.0)
require.NoError(s.T(), err, "DeductUserBalance should succeed")
bal, err := cache.GetUserBalance(ctx, 200)
require.NoError(s.T(), err)
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
},
},
{
name: "cancelled_context_propagates_error",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
cancelCtx, cancel := context.WithCancel(ctx)
cancel() // 立即取消
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
require.Error(s.T(), err, "cancelled context should propagate error")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, cache)
})
}
}
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
s.Run("key_not_exists_returns_nil", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
})
s.Run("cancelled_context_propagates_error", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
require.Error(s.T(), err, "cancelled context should propagate error")
})
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}
......@@ -5,6 +5,7 @@ package repository
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
......@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
})
}
}
func TestJitteredTTL(t *testing.T) {
const (
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
)
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
}
}
func TestJitteredTTL_HasVariation(t *testing.T) {
// 多次调用应该产生不同的值(验证抖动存在)
seen := make(map[time.Duration]struct{}, 50)
for i := 0; i < 50; i++ {
seen[jitteredTTL()] = struct{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
}
......@@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......
......@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......
......@@ -24,6 +24,22 @@ import (
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
"day": "YYYY-MM-DD",
"week": "IYYY-IW",
"month": "YYYY-MM",
}
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
func safeDateFormat(granularity string) string {
if f, ok := dateFormatWhitelist[granularity]; ok {
return f
}
return "YYYY-MM-DD"
}
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
......@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
return logs, nil, err
}
......@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
}
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err
}
......@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_keys AS (
......@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_users AS (
......@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT
......@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range userIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
......@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
if err != nil {
return nil, err
}
......@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
......@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
if err != nil {
return nil, err
}
......@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT
......
......@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
......@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
......@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
......
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSafeDateFormat(t *testing.T) {
tests := []struct {
name string
granularity string
expected string
}{
// 合法值
{"hour", "hour", "YYYY-MM-DD HH24:00"},
{"day", "day", "YYYY-MM-DD"},
{"week", "week", "IYYY-IW"},
{"month", "month", "YYYY-MM"},
// 非法值回退到默认
{"空字符串", "", "YYYY-MM-DD"},
{"未知粒度 year", "year", "YYYY-MM-DD"},
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
// 恶意字符串
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
{"带引号", "day'", "YYYY-MM-DD"},
{"带括号", "day)", "YYYY-MM-DD"},
{"Unicode", "日", "YYYY-MM-DD"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := safeDateFormat(tc.granularity)
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
})
}
}
......@@ -592,13 +592,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo, nil)
userService := service.NewUserService(userRepo, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
......@@ -1602,11 +1602,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
......
......@@ -176,6 +176,12 @@ func validateJWTForAdmin(
return false
}
// 校验 TokenVersion,确保管理员改密后旧 token 失效
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return false
}
// 检查管理员权限
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
......
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
Email: "admin@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
TokenVersion: 2,
Concurrency: 1,
}
userRepo := &stubUserRepo{
getByID: func(ctx context.Context, id int64) (*service.User, error) {
if id != admin.ID {
return nil, service.ErrUserNotFound
}
clone := *admin
return &clone, nil
},
}
userService := service.NewUserService(userRepo, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
}
type stubUserRepo struct {
getByID func(ctx context.Context, id int64) (*service.User, error)
}
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
panic("unexpected Create call")
}
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
if s.getByID == nil {
panic("GetByID not stubbed")
}
return s.getByID(ctx, id)
}
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
panic("unexpected GetByEmail call")
}
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
panic("unexpected Update call")
}
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
......@@ -3,7 +3,6 @@ package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
// 订阅模式:获取订阅(L1 缓存 + singleflight)
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
......@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
AbortWithError(c, status, code, err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
......
......@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
......@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
......
......@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
// 处理预检请求
if c.Request.Method == http.MethodOptions {
......
......@@ -36,8 +36,8 @@ type UsageLogRepository interface {
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
......
......@@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return changed, nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
usage := &ClaudeUsage{}
var firstTokenMs *int
var firstTokenRecorded bool
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
// 记录首 token 时间
if !firstTokenRecorded && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
firstTokenRecorded = true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if bytes.HasPrefix(line, []byte("data: ")) {
dataStr := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]any
if json.Unmarshal(dataStr, &event) == nil {
if u, ok := event["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
}
}
// 透传行
_, _ = c.Writer.Write(line)
_, _ = c.Writer.Write([]byte("\n"))
c.Writer.Flush()
}
return usage, firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
......@@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
default:
......@@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
var firstTokenMs *int
var last map[string]any
......@@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
......@@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment