Unverified Commit 07be258d authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #603 from mt21625457/release

feat: 大幅度的性能优化和新增了很多功能
parents dbdb2959 53d55bb9
......@@ -315,3 +315,69 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
require.NotEmpty(t, newToken)
})
}
// TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour verifies that when
// AccessTokenExpireMinutes is unset (0), the expiry falls back to ExpireHour.
func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) {
	svc := newAuthService(&userRepoStub{}, nil, nil)
	svc.cfg.JWT.ExpireHour = 24
	svc.cfg.JWT.AccessTokenExpireMinutes = 0

	require.Equal(t, 24*3600, svc.GetAccessTokenExpiresIn())
}

// TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority verifies that a
// non-zero AccessTokenExpireMinutes takes precedence over ExpireHour.
func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
	svc := newAuthService(&userRepoStub{}, nil, nil)
	svc.cfg.JWT.ExpireHour = 24
	svc.cfg.JWT.AccessTokenExpireMinutes = 90

	require.Equal(t, 90*60, svc.GetAccessTokenExpiresIn())
}
// TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero verifies tokens
// expire cfg.JWT.ExpireHour hours after issuance when AccessTokenExpireMinutes is unset.
func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
	svc := newAuthService(&userRepoStub{}, nil, nil)
	svc.cfg.JWT.ExpireHour = 24
	svc.cfg.JWT.AccessTokenExpireMinutes = 0

	u := &User{
		ID:           1,
		Email:        "test@test.com",
		Role:         RoleUser,
		Status:       StatusActive,
		TokenVersion: 1,
	}
	token, err := svc.GenerateToken(u)
	require.NoError(t, err)

	claims, err := svc.ValidateToken(token)
	require.NoError(t, err)
	require.NotNil(t, claims)
	require.NotNil(t, claims.IssuedAt)
	require.NotNil(t, claims.ExpiresAt)
	// Expiry should be issued-at + 24h (small tolerance for test runtime).
	require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second)
}

// TestAuthService_GenerateToken_UsesMinutesWhenConfigured verifies a configured
// AccessTokenExpireMinutes overrides ExpireHour for token expiry.
func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
	svc := newAuthService(&userRepoStub{}, nil, nil)
	svc.cfg.JWT.ExpireHour = 24
	svc.cfg.JWT.AccessTokenExpireMinutes = 90

	u := &User{
		ID:           2,
		Email:        "test2@test.com",
		Role:         RoleUser,
		Status:       StatusActive,
		TokenVersion: 1,
	}
	token, err := svc.GenerateToken(u)
	require.NoError(t, err)

	claims, err := svc.ValidateToken(token)
	require.NoError(t, err)
	require.NotNil(t, claims)
	require.NotNil(t, claims.IssuedAt)
	require.NotNil(t, claims.ExpiresAt)
	require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
}
......@@ -3,13 +3,13 @@ package service
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// 错误定义
......@@ -156,13 +156,13 @@ func (s *BillingCacheService) cacheWriteWorker() {
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
}
......@@ -216,7 +216,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri
if dropped == 0 {
return
}
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
......@@ -274,7 +274,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
return
}
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err)
}
}
......@@ -302,7 +302,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
......@@ -312,7 +312,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
return nil
}
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
return nil
......@@ -396,7 +396,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
return
}
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
......@@ -425,7 +425,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
......@@ -435,7 +435,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil
}
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
......@@ -474,7 +474,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
......@@ -496,7 +496,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
......@@ -585,7 +585,7 @@ func (b *billingCircuitBreaker) Allow() bool {
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
......@@ -612,7 +612,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
......@@ -620,7 +620,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
......@@ -641,9 +641,9 @@ func (b *billingCircuitBreaker) OnSuccess() {
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
......
......@@ -312,7 +312,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, nil // 出错时返回范围内成本
return inRangeCost, fmt.Errorf("out-range cost: %w", err)
}
// 合并成本
......@@ -388,6 +388,14 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
}
// SoraPriceConfig holds per-request (pay-per-use) pricing for Sora generation.
// All fields are pointers: nil means "not configured", in which case callers
// fall back to a zero unit price.
type SoraPriceConfig struct {
ImagePrice360 *float64 // price per image at "360" size (also the default for any other size)
ImagePrice540 *float64 // price per image at "540" size
VideoPricePerRequest *float64 // flat price per standard video request
VideoPricePerRequestHD *float64 // flat price per HD ("sora2pro-hd") video request
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
......@@ -417,6 +425,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
// CalculateSoraImageCost computes the pay-per-request cost for Sora image generation.
// imageSize selects the unit price ("540" uses ImagePrice540; any other value falls
// back to ImagePrice360). A nil groupConfig or an unset price yields zero cost, and
// a non-positive rateMultiplier is treated as 1.0. TotalCost is pre-multiplier;
// ActualCost has the multiplier applied.
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
	if imageCount <= 0 {
		return &CostBreakdown{}
	}

	var unitPrice float64
	if groupConfig != nil {
		var configured *float64
		switch imageSize {
		case "540":
			configured = groupConfig.ImagePrice540
		default:
			configured = groupConfig.ImagePrice360
		}
		if configured != nil {
			unitPrice = *configured
		}
	}

	if rateMultiplier <= 0 {
		rateMultiplier = 1.0
	}

	total := unitPrice * float64(imageCount)
	return &CostBreakdown{
		TotalCost:  total,
		ActualCost: total * rateMultiplier,
	}
}
// CalculateSoraVideoCost computes the pay-per-request cost for Sora video generation.
// Models whose lowercased name contains "sora2pro-hd" use VideoPricePerRequestHD;
// otherwise — or when the HD price is unset/non-positive — VideoPricePerRequest
// applies. A non-positive rateMultiplier is treated as 1.0.
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
	var unitPrice float64
	if groupConfig != nil {
		if strings.Contains(strings.ToLower(model), "sora2pro-hd") && groupConfig.VideoPricePerRequestHD != nil {
			unitPrice = *groupConfig.VideoPricePerRequestHD
		}
		// Fall back to the standard per-request price when no positive HD price applied.
		if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
			unitPrice = *groupConfig.VideoPricePerRequest
		}
	}

	if rateMultiplier <= 0 {
		rateMultiplier = 1.0
	}

	return &CostBreakdown{
		TotalCost:  unitPrice,
		ActualCost: unitPrice * rateMultiplier,
	}
}
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
......
//go:build unit
package service
import (
"math"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// newTestBillingService builds a BillingService with an empty config and no
// pricing service, so the tests exercise the built-in fallback price table.
func newTestBillingService() *BillingService {
return NewBillingService(&config.Config{}, nil)
}
// TestCalculateCost_BasicComputation checks input/output costs against the
// claude-sonnet-4 fallback prices: input $3/MTok, output $15/MTok.
func TestCalculateCost_BasicComputation(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{
		InputTokens:  1000,
		OutputTokens: 500,
	}

	cost, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	// 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075
	wantInput := 1000 * 3e-6
	wantOutput := 500 * 15e-6
	require.InDelta(t, wantInput, cost.InputCost, 1e-10)
	require.InDelta(t, wantOutput, cost.OutputCost, 1e-10)
	require.InDelta(t, wantInput+wantOutput, cost.TotalCost, 1e-10)
	require.InDelta(t, wantInput+wantOutput, cost.ActualCost, 1e-10)
}
// TestCalculateCost_WithCacheTokens checks cache-creation and cache-read tokens
// are billed at the fallback rates and included in the total.
func TestCalculateCost_WithCacheTokens(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{
		InputTokens:         1000,
		OutputTokens:        500,
		CacheCreationTokens: 2000,
		CacheReadTokens:     3000,
	}

	cost, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	wantCreation := 2000 * 3.75e-6
	wantRead := 3000 * 0.3e-6
	require.InDelta(t, wantCreation, cost.CacheCreationCost, 1e-10)
	require.InDelta(t, wantRead, cost.CacheReadCost, 1e-10)
	require.InDelta(t, cost.InputCost+cost.OutputCost+wantCreation+wantRead, cost.TotalCost, 1e-10)
}

// TestCalculateCost_RateMultiplier checks the multiplier scales ActualCost only;
// TotalCost stays the unmultiplied base cost.
func TestCalculateCost_RateMultiplier(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{InputTokens: 1000, OutputTokens: 500}

	cost1x, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)
	cost2x, err := svc.CalculateCost("claude-sonnet-4", usage, 2.0)
	require.NoError(t, err)

	require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10)
	require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
// TestCalculateCost_ZeroMultiplierDefaultsToOne: a multiplier of 0 is treated as 1.0.
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{InputTokens: 1000}

	got, err := svc.CalculateCost("claude-sonnet-4", usage, 0)
	require.NoError(t, err)
	want, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	require.InDelta(t, want.ActualCost, got.ActualCost, 1e-10)
}

// TestCalculateCost_NegativeMultiplierDefaultsToOne: negative multipliers also
// fall back to 1.0.
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{InputTokens: 1000}

	got, err := svc.CalculateCost("claude-sonnet-4", usage, -1.0)
	require.NoError(t, err)
	want, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	require.InDelta(t, want.ActualCost, got.ActualCost, 1e-10)
}
// TestGetModelPricing_FallbackMatchesByFamily checks fallback pricing resolves
// by model family (opus/sonnet/haiku, with generation-specific rates).
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
	svc := newTestBillingService()
	cases := []struct {
		model         string
		expectedInput float64
	}{
		{"claude-opus-4.5-20250101", 5e-6},
		{"claude-3-opus-20240229", 15e-6},
		{"claude-sonnet-4-20250514", 3e-6},
		{"claude-3-5-sonnet-20241022", 3e-6},
		{"claude-3-5-haiku-20241022", 1e-6},
		{"claude-3-haiku-20240307", 0.25e-6},
	}
	for _, tc := range cases {
		pricing, err := svc.GetModelPricing(tc.model)
		require.NoError(t, err, "模型 %s", tc.model)
		require.InDelta(t, tc.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tc.model)
	}
}

// TestGetModelPricing_CaseInsensitive checks model lookup ignores letter case.
func TestGetModelPricing_CaseInsensitive(t *testing.T) {
	svc := newTestBillingService()

	upper, err := svc.GetModelPricing("Claude-Sonnet-4")
	require.NoError(t, err)
	lower, err := svc.GetModelPricing("claude-sonnet-4")
	require.NoError(t, err)

	require.Equal(t, upper.InputPricePerToken, lower.InputPricePerToken)
}

// TestGetModelPricing_UnknownModelFallsBackToSonnet: Claude models without an
// opus/sonnet/haiku keyword get the default Sonnet price.
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
	svc := newTestBillingService()

	pricing, err := svc.GetModelPricing("claude-unknown-model")
	require.NoError(t, err)
	require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
// TestCalculateCostWithLongContext_BelowThreshold: total input under the
// threshold must be billed exactly like normal CalculateCost.
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{
		InputTokens:     50000,
		OutputTokens:    1000,
		CacheReadTokens: 100000,
	}

	// 150k total input < 200k threshold.
	got, err := svc.CalculateCostWithLongContext("claude-sonnet-4", usage, 1.0, 200000, 2.0)
	require.NoError(t, err)
	want, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	require.InDelta(t, want.ActualCost, got.ActualCost, 1e-10)
}

// TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold: when
// the cache alone exceeds the threshold, the in-range slice is all cache and
// the overflow (remaining cache + all input) gets the extra multiplier.
func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) {
	svc := newTestBillingService()
	// 210k cache + 10k input = 220k > 200k threshold.
	usage := UsageTokens{
		InputTokens:     10000,
		OutputTokens:    1000,
		CacheReadTokens: 210000,
	}

	got, err := svc.CalculateCostWithLongContext("claude-sonnet-4", usage, 1.0, 200000, 2.0)
	require.NoError(t, err)

	// In range: 200k cache + 0 input + 1k output at 1.0x.
	inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
		InputTokens:     0,
		OutputTokens:    1000,
		CacheReadTokens: 200000,
	}, 1.0)
	// Out of range: 10k cache + 10k input at 2.0x.
	outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
		InputTokens:     10000,
		CacheReadTokens: 10000,
	}, 2.0)

	require.InDelta(t, inRange.ActualCost+outRange.ActualCost, got.ActualCost, 1e-10)
}
// TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold: when the
// cache is under the threshold, input fills the remaining in-range budget and
// the rest is billed at the long-context multiplier, so the cost must exceed
// the normal (non-long-context) cost.
func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) {
	svc := newTestBillingService()
	// 100k cache + 150k input = 250k > 200k threshold;
	// in range: 100k cache + 100k input, out of range: 50k input.
	usage := UsageTokens{
		InputTokens:     150000,
		OutputTokens:    1000,
		CacheReadTokens: 100000,
	}

	got, err := svc.CalculateCostWithLongContext("claude-sonnet-4", usage, 1.0, 200000, 2.0)
	require.NoError(t, err)
	require.True(t, got.ActualCost > 0, "费用应大于 0")

	normal, _ := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.True(t, got.ActualCost > normal.ActualCost, "长上下文费用应高于正常费用")
}

// TestCalculateCostWithLongContext_DisabledThreshold: threshold <= 0 disables
// long-context billing entirely.
func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}

	withLC, err := svc.CalculateCostWithLongContext("claude-sonnet-4", usage, 1.0, 0, 2.0)
	require.NoError(t, err)
	normal, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	require.InDelta(t, normal.ActualCost, withLC.ActualCost, 1e-10)
}

// TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne: an extra
// multiplier <= 1 also disables long-context billing.
func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{InputTokens: 300000}

	withLC, err := svc.CalculateCostWithLongContext("claude-sonnet-4", usage, 1.0, 200000, 1.0)
	require.NoError(t, err)
	normal, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	require.InDelta(t, normal.ActualCost, withLC.ActualCost, 1e-10)
}
// TestCalculateImageCost checks per-image billing with a group-configured 1K price.
func TestCalculateImageCost(t *testing.T) {
	svc := newTestBillingService()
	unit := 0.134

	cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, &ImagePriceConfig{Price1K: &unit}, 1.0)
	require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10)
	require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}

// TestCalculateSoraVideoCost checks the flat per-request video price is used.
func TestCalculateSoraVideoCost(t *testing.T) {
	svc := newTestBillingService()
	unit := 0.5

	cost := svc.CalculateSoraVideoCost("sora-video", &SoraPriceConfig{VideoPricePerRequest: &unit}, 1.0)
	require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
}

// TestCalculateSoraVideoCost_HDModel checks HD models are billed at the HD price,
// not the standard per-request price.
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
	svc := newTestBillingService()
	hd, normal := 1.0, 0.5
	cfg := &SoraPriceConfig{
		VideoPricePerRequest:   &normal,
		VideoPricePerRequestHD: &hd,
	}

	cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
	require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
}
// TestIsModelSupported checks Claude-family models (any case) are supported and
// non-Claude models are not.
func TestIsModelSupported(t *testing.T) {
	svc := newTestBillingService()
	for _, model := range []string{"claude-sonnet-4", "Claude-Opus-4.5", "claude-3-haiku"} {
		require.True(t, svc.IsModelSupported(model))
	}
	for _, model := range []string{"gpt-4o", "gemini-pro"} {
		require.False(t, svc.IsModelSupported(model))
	}
}

// TestCalculateCost_ZeroTokens: zero usage must yield exactly zero cost.
func TestCalculateCost_ZeroTokens(t *testing.T) {
	svc := newTestBillingService()

	cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0)
	require.NoError(t, err)
	require.Equal(t, 0.0, cost.TotalCost)
	require.Equal(t, 0.0, cost.ActualCost)
}
// TestCalculateCostWithConfig checks the configured default rate multiplier is applied.
func TestCalculateCostWithConfig(t *testing.T) {
	cfg := &config.Config{}
	cfg.Default.RateMultiplier = 1.5
	svc := NewBillingService(cfg, nil)
	usage := UsageTokens{InputTokens: 1000, OutputTokens: 500}

	got, err := svc.CalculateCostWithConfig("claude-sonnet-4", usage)
	require.NoError(t, err)

	want, _ := svc.CalculateCost("claude-sonnet-4", usage, 1.5)
	require.InDelta(t, want.ActualCost, got.ActualCost, 1e-10)
}

// TestCalculateCostWithConfig_ZeroMultiplier: a configured multiplier <= 0
// defaults to 1.0.
func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) {
	cfg := &config.Config{}
	cfg.Default.RateMultiplier = 0
	svc := NewBillingService(cfg, nil)
	usage := UsageTokens{InputTokens: 1000}

	got, err := svc.CalculateCostWithConfig("claude-sonnet-4", usage)
	require.NoError(t, err)

	want, _ := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.InDelta(t, want.ActualCost, got.ActualCost, 1e-10)
}
// TestGetEstimatedCost: estimation for a supported model must be positive.
func TestGetEstimatedCost(t *testing.T) {
	svc := newTestBillingService()

	est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500)
	require.NoError(t, err)
	require.True(t, est > 0)
}

// TestListSupportedModels: the fallback table exposes at least six models.
func TestListSupportedModels(t *testing.T) {
	svc := newTestBillingService()

	models := svc.ListSupportedModels()
	require.NotEmpty(t, models)
	require.GreaterOrEqual(t, len(models), 6)
}

// TestGetPricingServiceStatus_NilService: without a pricing service the status
// reports the fallback table as the price source.
func TestGetPricingServiceStatus_NilService(t *testing.T) {
	svc := newTestBillingService()

	status := svc.GetPricingServiceStatus()
	require.NotNil(t, status)
	require.Equal(t, "using fallback", status["last_updated"])
}

// TestForceUpdatePricing_NilService: forcing a price refresh without a pricing
// service must fail.
func TestForceUpdatePricing_NilService(t *testing.T) {
	svc := newTestBillingService()

	err := svc.ForceUpdatePricing()
	require.Error(t, err)
	require.Contains(t, err.Error(), "not initialized")
}
// TestCalculateSoraImageCost checks per-size unit prices and the rate multiplier.
func TestCalculateSoraImageCost(t *testing.T) {
	svc := newTestBillingService()
	p360, p540 := 0.05, 0.08
	cfg := &SoraPriceConfig{ImagePrice360: &p360, ImagePrice540: &p540}

	cost360 := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
	require.InDelta(t, 0.10, cost360.TotalCost, 1e-10)

	// TotalCost is pre-multiplier; ActualCost applies the 2.0x rate.
	cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
	require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
	require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
}

// TestCalculateSoraImageCost_ZeroCount: zero images cost nothing.
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
	svc := newTestBillingService()
	require.Equal(t, 0.0, svc.CalculateSoraImageCost("360", 0, nil, 1.0).TotalCost)
}

// TestCalculateSoraVideoCost_NilConfig: no group config means a zero unit price.
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
	svc := newTestBillingService()
	require.Equal(t, 0.0, svc.CalculateSoraVideoCost("sora-video", nil, 1.0).TotalCost)
}
// TestCalculateCostWithLongContext_PropagatesError: with an empty fallback
// table, pricing lookup fails and the error must propagate to the caller.
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
	svc := &BillingService{
		cfg:            &config.Config{},
		fallbackPrices: make(map[string]*ModelPricing),
	}
	usage := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}

	_, err := svc.CalculateCostWithLongContext("unknown-model", usage, 1.0, 200000, 2.0)
	require.Error(t, err)
	require.Contains(t, err.Error(), "pricing not found")
}

// TestCalculateCost_SupportsCacheBreakdown: when a model supports split cache
// pricing, 5m and 1h cache-creation tokens are billed at their own rates.
func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
	svc := &BillingService{
		cfg: &config.Config{},
		fallbackPrices: map[string]*ModelPricing{
			"claude-sonnet-4": {
				InputPricePerToken:     3e-6,
				OutputPricePerToken:    15e-6,
				SupportsCacheBreakdown: true,
				CacheCreation5mPrice:   4e-6, // per token
				CacheCreation1hPrice:   5e-6, // per token
			},
		},
	}
	usage := UsageTokens{
		InputTokens:           1000,
		OutputTokens:          500,
		CacheCreation5mTokens: 100000,
		CacheCreation1hTokens: 50000,
	}

	cost, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	want5m := float64(usage.CacheCreation5mTokens) * 4e-6
	want1h := float64(usage.CacheCreation1hTokens) * 5e-6
	require.InDelta(t, want5m+want1h, cost.CacheCreationCost, 1e-10)
}
// TestCalculateCost_LargeTokenCount: million-token usage stays finite and accurate.
func TestCalculateCost_LargeTokenCount(t *testing.T) {
	svc := newTestBillingService()
	usage := UsageTokens{
		InputTokens:  1_000_000,
		OutputTokens: 1_000_000,
	}

	cost, err := svc.CalculateCost("claude-sonnet-4", usage, 1.0)
	require.NoError(t, err)

	// Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15.
	require.InDelta(t, 3.0, cost.InputCost, 1e-6)
	require.InDelta(t, 15.0, cost.OutputCost, 1e-6)
	require.False(t, math.IsNaN(cost.TotalCost))
	require.False(t, math.IsInf(cost.TotalCost, 0))
}
//go:build unit
package service
import (
"context"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// newTestValidator returns a fresh validator under test.
func newTestValidator() *ClaudeCodeValidator {
	return NewClaudeCodeValidator()
}

// validClaudeCodeBody builds a fully valid Claude Code request body: a known
// system prompt plus a well-formed metadata.user_id (64-hex account + session UUID).
func validClaudeCodeBody() map[string]any {
	body := map[string]any{}
	body["model"] = "claude-sonnet-4-20250514"
	body["system"] = []any{
		map[string]any{
			"type": "text",
			"text": "You are Claude Code, Anthropic's official CLI for Claude.",
		},
	}
	body["metadata"] = map[string]any{
		"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc",
	}
	return body
}
// TestValidate_ClaudeCLIUserAgent exercises the User-Agent matcher: it accepts
// "claude-cli/x.y.z" (case differences allowed) and rejects everything else,
// including missing or two-part version numbers.
func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
	v := newTestValidator()
	cases := []struct {
		name string
		ua   string
		want bool
	}{
		{"标准版本号", "claude-cli/1.0.0", true},
		{"多位版本号", "claude-cli/12.34.56", true},
		{"大写开头", "Claude-CLI/1.0.0", true},
		{"非 claude-cli", "curl/7.64.1", false},
		{"空 User-Agent", "", false},
		{"部分匹配", "not-claude-cli/1.0.0", false},
		{"缺少版本号", "claude-cli/", false},
		{"版本格式不对", "claude-cli/1.0", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, v.ValidateUserAgent(tc.ua), "UA: %q", tc.ua)
		})
	}
}
// TestValidate_NonMessagesPath_UAOnly: non-/v1/messages paths only need a valid UA.
func TestValidate_NonMessagesPath_UAOnly(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("GET", "/v1/models", nil)
	req.Header.Set("User-Agent", "claude-cli/1.0.0")

	require.True(t, v.Validate(req, nil), "非 messages 路径只需 UA 匹配")
}

// TestValidate_NonMessagesPath_InvalidUA: a bad UA fails even off the messages path.
func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("GET", "/v1/models", nil)
	req.Header.Set("User-Agent", "curl/7.64.1")

	require.False(t, v.Validate(req, nil), "UA 不匹配时应返回 false")
}
// TestValidate_MessagesPath_FullValid: a request carrying all required headers
// and a valid body passes full validation on /v1/messages.
func TestValidate_MessagesPath_FullValid(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("POST", "/v1/messages", nil)
	for header, value := range map[string]string{
		"User-Agent":        "claude-cli/1.0.0",
		"X-App":             "claude-code",
		"anthropic-beta":    "max-tokens-3-5-sonnet-2024-07-15",
		"anthropic-version": "2023-06-01",
	} {
		req.Header.Set(header, value)
	}

	require.True(t, v.Validate(req, validClaudeCodeBody()), "完整有效请求应通过")
}

// TestValidate_MessagesPath_MissingHeaders: removing any single required header
// must fail validation.
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
	v := newTestValidator()
	body := validClaudeCodeBody()
	cases := []struct {
		name          string
		missingHeader string
	}{
		{"缺少 X-App", "X-App"},
		{"缺少 anthropic-beta", "anthropic-beta"},
		{"缺少 anthropic-version", "anthropic-version"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("POST", "/v1/messages", nil)
			req.Header.Set("User-Agent", "claude-cli/1.0.0")
			req.Header.Set("X-App", "claude-code")
			req.Header.Set("anthropic-beta", "beta")
			req.Header.Set("anthropic-version", "2023-06-01")
			req.Header.Del(tc.missingHeader)

			require.False(t, v.Validate(req, body), "缺少 %s 应返回 false", tc.missingHeader)
		})
	}
}
// TestValidate_MessagesPath_InvalidMetadataUserID: missing or malformed
// metadata.user_id values must be rejected even when all headers are valid.
func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) {
	v := newTestValidator()
	cases := []struct {
		name     string
		metadata map[string]any
	}{
		{"缺少 metadata", nil},
		{"缺少 user_id", map[string]any{"other": "value"}},
		{"空 user_id", map[string]any{"user_id": ""}},
		{"格式错误", map[string]any{"user_id": "invalid-format"}},
		{"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("POST", "/v1/messages", nil)
			req.Header.Set("User-Agent", "claude-cli/1.0.0")
			req.Header.Set("X-App", "claude-code")
			req.Header.Set("anthropic-beta", "beta")
			req.Header.Set("anthropic-version", "2023-06-01")

			body := map[string]any{
				"model": "claude-sonnet-4",
				"system": []any{
					map[string]any{
						"type": "text",
						"text": "You are Claude Code, Anthropic's official CLI for Claude.",
					},
				},
			}
			if tc.metadata != nil {
				body["metadata"] = tc.metadata
			}

			require.False(t, v.Validate(req, body), "metadata.user_id: %v", tc.metadata)
		})
	}
}
// TestValidate_MessagesPath_InvalidSystemPrompt: an unrelated system prompt
// fails validation even with valid headers and a well-formed user_id.
func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("POST", "/v1/messages", nil)
	req.Header.Set("User-Agent", "claude-cli/1.0.0")
	req.Header.Set("X-App", "claude-code")
	req.Header.Set("anthropic-beta", "beta")
	req.Header.Set("anthropic-version", "2023-06-01")

	body := map[string]any{
		"model": "claude-sonnet-4",
		"system": []any{
			map[string]any{
				"type": "text",
				"text": "Generate JSON data for testing database migrations.",
			},
		},
		"metadata": map[string]any{
			"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc",
		},
	}

	require.False(t, v.Validate(req, body), "无关系统提示词应返回 false")
}
// TestValidate_MaxTokensOneHaikuBypass: requests flagged in context as
// max_tokens=1 haiku probes bypass the strict header/body validation.
func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("POST", "/v1/messages", nil)
	req.Header.Set("User-Agent", "claude-cli/1.0.0")

	// No X-App etc.; mark the request as a haiku probe via context instead.
	req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))

	// Should pass even though the body carries no system prompt.
	body := map[string]any{"model": "claude-3-haiku", "max_tokens": 1}
	require.True(t, v.Validate(req, body), "max_tokens=1+haiku 探测请求应绕过严格验证")
}
// TestSystemPromptSimilarity checks which system prompts the similarity matcher
// accepts as Claude-Code-like.
func TestSystemPromptSimilarity(t *testing.T) {
	v := newTestValidator()
	cases := []struct {
		name   string
		prompt string
		want   bool
	}{
		{"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true},
		// NOTE(review): this prompt looks byte-identical to the exact-match case —
		// the "extra spaces" it claims to test may have been lost; confirm upstream.
		{"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true},
		{"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true},
		{"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true},
		{"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true},
		{"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true},
		{"无关文本", "Write me a poem about cats", false},
		{"空文本", "", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			body := map[string]any{
				"model": "claude-sonnet-4",
				"system": []any{
					map[string]any{"type": "text", "text": tc.prompt},
				},
			}
			require.Equal(t, tc.want, v.IncludesClaudeCodeSystemPrompt(body), "提示词: %q", tc.prompt)
		})
	}
}
func TestDiceCoefficient(t *testing.T) {
tests := []struct {
name string
a string
b string
want float64
tol float64
}{
{"相同字符串", "hello", "hello", 1.0, 0.001},
{"完全不同", "abc", "xyz", 0.0, 0.001},
{"空字符串", "", "hello", 0.0, 0.001},
{"单字符", "a", "b", 0.0, 0.001},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := diceCoefficient(tt.a, tt.b)
require.InDelta(t, tt.want, result, tt.tol)
})
}
}
// TestIsClaudeCodeClient_Context checks the context getter/setter round-trip.
func TestIsClaudeCodeClient_Context(t *testing.T) {
	ctx := context.Background()

	// Unset defaults to false.
	require.False(t, IsClaudeCodeClient(ctx))

	ctx = SetClaudeCodeClient(ctx, true)
	require.True(t, IsClaudeCodeClient(ctx))

	ctx = SetClaudeCodeClient(ctx, false)
	require.False(t, IsClaudeCodeClient(ctx))
}
// TestValidate_NilBody_MessagesPath: a nil body on /v1/messages must fail even
// when every required header is present.
func TestValidate_NilBody_MessagesPath(t *testing.T) {
	v := newTestValidator()
	req := httptest.NewRequest("POST", "/v1/messages", nil)
	req.Header.Set("User-Agent", "claude-cli/1.0.0")
	req.Header.Set("X-App", "claude-code")
	req.Header.Set("anthropic-beta", "beta")
	req.Header.Set("anthropic-version", "2023-06-01")

	require.False(t, v.Validate(req, nil), "nil body 的 messages 请求应返回 false")
}
......@@ -5,8 +5,9 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// ConcurrencyCache 定义并发控制的缓存接口
......@@ -124,7 +125,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
}
},
}, nil
......@@ -163,7 +164,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
}
},
}, nil
......@@ -191,7 +192,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
if err != nil {
// On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result, nil
......@@ -209,7 +210,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
defer cancel()
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
......@@ -221,7 +222,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
......@@ -237,7 +238,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
......@@ -293,7 +294,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
......@@ -301,7 +302,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
......
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
// stubConcurrencyCacheForTest is a configurable ConcurrencyCache stub used by
// the concurrency-service unit tests. Each field pre-programs the value or
// error returned by the corresponding interface method.
type stubConcurrencyCacheForTest struct {
	acquireResult  bool  // returned by Acquire{Account,User}Slot
	acquireErr     error // error returned by Acquire{Account,User}Slot
	releaseErr     error // error returned by Release{Account,User}Slot
	concurrency    int   // returned by Get{Account,User}Concurrency
	concurrencyErr error // error returned by Get{Account,User}Concurrency
	waitAllowed    bool  // returned by Increment{Account,}WaitCount
	waitErr        error // error returned by Increment{Account,}WaitCount
	waitCount      int   // returned by GetAccountWaitingCount
	waitCountErr   error // error returned by GetAccountWaitingCount
	loadBatch      map[int64]*AccountLoadInfo // returned by GetAccountsLoadBatch
	loadBatchErr   error                      // error returned by GetAccountsLoadBatch
	usersLoadBatch map[int64]*UserLoadInfo    // returned by GetUsersLoadBatch
	usersLoadErr   error                      // error returned by GetUsersLoadBatch
	cleanupErr     error                      // returned by CleanupExpiredAccountSlots

	// Call recording: arguments captured by ReleaseAccountSlot, in call order.
	releasedAccountIDs []int64
	releasedRequestIDs []string
}

// Compile-time check that the stub satisfies the ConcurrencyCache interface.
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
// AcquireAccountSlot returns the pre-configured acquire result and error.
func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
	return c.acquireResult, c.acquireErr
}
// ReleaseAccountSlot records the released account/request IDs for later
// assertions and returns the pre-configured release error.
func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error {
	c.releasedAccountIDs = append(c.releasedAccountIDs, accountID)
	c.releasedRequestIDs = append(c.releasedRequestIDs, requestID)
	return c.releaseErr
}
// GetAccountConcurrency returns the pre-configured concurrency value and error.
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
	return c.concurrency, c.concurrencyErr
}
// IncrementAccountWaitCount returns the pre-configured wait decision and error.
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
	return c.waitAllowed, c.waitErr
}
// DecrementAccountWaitCount is a no-op that always succeeds.
func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error {
	return nil
}
// GetAccountWaitingCount returns the pre-configured waiting count and error.
func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) {
	return c.waitCount, c.waitCountErr
}
// AcquireUserSlot returns the same pre-configured result as AcquireAccountSlot.
func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
	return c.acquireResult, c.acquireErr
}
// ReleaseUserSlot returns the pre-configured release error without recording.
func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error {
	return c.releaseErr
}
// GetUserConcurrency returns the pre-configured concurrency value and error.
func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) {
	return c.concurrency, c.concurrencyErr
}
// IncrementWaitCount returns the pre-configured wait decision and error.
func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
	return c.waitAllowed, c.waitErr
}
// DecrementWaitCount is a no-op that always succeeds.
func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error {
	return nil
}
// GetAccountsLoadBatch returns the pre-configured load map and error.
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	return c.loadBatch, c.loadBatchErr
}
// GetUsersLoadBatch returns the pre-configured user load map and error.
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
	return c.usersLoadBatch, c.usersLoadErr
}
// CleanupExpiredAccountSlots returns the pre-configured cleanup error.
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
	return c.cleanupErr
}
// TestAcquireAccountSlot_Success: a granted slot yields Acquired=true and a
// non-nil release callback.
func TestAcquireAccountSlot_Success(t *testing.T) {
	stub := &stubConcurrencyCacheForTest{acquireResult: true}
	service := NewConcurrencyService(stub)
	res, err := service.AcquireAccountSlot(context.Background(), 1, 5)
	require.NoError(t, err)
	require.True(t, res.Acquired)
	require.NotNil(t, res.ReleaseFunc)
}
func TestAcquireAccountSlot_Failure(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: false}
svc := NewConcurrencyService(cache)
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
require.NoError(t, err)
require.False(t, result.Acquired)
require.Nil(t, result.ReleaseFunc)
}
// TestAcquireAccountSlot_UnlimitedConcurrency: non-positive limits mean
// "unlimited" and must always succeed with a usable release callback.
func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) {
	service := NewConcurrencyService(&stubConcurrencyCacheForTest{})
	for _, limit := range []int{0, -1} {
		res, err := service.AcquireAccountSlot(context.Background(), 1, limit)
		require.NoError(t, err)
		require.True(t, res.Acquired, "maxConcurrency=%d 应无限制通过", limit)
		require.NotNil(t, res.ReleaseFunc, "ReleaseFunc 应为 no-op 函数")
	}
}
// TestAcquireAccountSlot_CacheError: a cache failure propagates as an error
// with a nil result.
func TestAcquireAccountSlot_CacheError(t *testing.T) {
	stub := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")}
	service := NewConcurrencyService(stub)
	res, err := service.AcquireAccountSlot(context.Background(), 1, 5)
	require.Error(t, err)
	require.Nil(t, res)
}
// TestAcquireAccountSlot_ReleaseDecrements: invoking ReleaseFunc must release
// the slot back to the cache with the same account ID and a non-empty request ID.
func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) {
	stub := &stubConcurrencyCacheForTest{acquireResult: true}
	service := NewConcurrencyService(stub)
	res, err := service.AcquireAccountSlot(context.Background(), 42, 5)
	require.NoError(t, err)
	require.True(t, res.Acquired)
	// Releasing should hit the cache exactly once.
	res.ReleaseFunc()
	require.Len(t, stub.releasedAccountIDs, 1)
	require.Equal(t, int64(42), stub.releasedAccountIDs[0])
	require.Len(t, stub.releasedRequestIDs, 1)
	require.NotEmpty(t, stub.releasedRequestIDs[0], "requestID 不应为空")
}
func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)
// 用户槽位获取应独立于账户槽位
result, err := svc.AcquireUserSlot(context.Background(), 100, 3)
require.NoError(t, err)
require.True(t, result.Acquired)
require.NotNil(t, result.ReleaseFunc)
}
// TestAcquireUserSlot_UnlimitedConcurrency: a zero limit means unlimited and
// must always acquire.
func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
	service := NewConcurrencyService(&stubConcurrencyCacheForTest{})
	res, err := service.AcquireUserSlot(context.Background(), 1, 0)
	require.NoError(t, err)
	require.True(t, res.Acquired)
}
// TestGetAccountsLoadBatch_ReturnsCorrectData: the service passes the cache's
// load map through unchanged.
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
	want := map[int64]*AccountLoadInfo{
		1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
		2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100},
	}
	service := NewConcurrencyService(&stubConcurrencyCacheForTest{loadBatch: want})
	input := []AccountWithConcurrency{
		{ID: 1, MaxConcurrency: 5},
		{ID: 2, MaxConcurrency: 5},
	}
	got, err := service.GetAccountsLoadBatch(context.Background(), input)
	require.NoError(t, err)
	require.Equal(t, want, got)
}
// TestGetAccountsLoadBatch_NilCache: a nil cache degrades to an empty result
// instead of an error.
func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
	service := &ConcurrencyService{cache: nil}
	got, err := service.GetAccountsLoadBatch(context.Background(), nil)
	require.NoError(t, err)
	require.Empty(t, got)
}
func TestIncrementWaitCount_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err)
require.True(t, allowed)
}
func TestIncrementWaitCount_QueueFull(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: false}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err)
require.False(t, allowed)
}
// TestIncrementWaitCount_FailOpen: Redis errors fail open — the request is
// allowed and the error is swallowed.
func TestIncrementWaitCount_FailOpen(t *testing.T) {
	stub := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")}
	service := NewConcurrencyService(stub)
	allowed, err := service.IncrementWaitCount(context.Background(), 1, 25)
	require.NoError(t, err, "Redis 错误不应传播")
	require.True(t, allowed, "Redis 错误时应 fail-open")
}
// TestIncrementWaitCount_NilCache: without a cache the service fails open.
func TestIncrementWaitCount_NilCache(t *testing.T) {
	service := &ConcurrencyService{cache: nil}
	allowed, err := service.IncrementWaitCount(context.Background(), 1, 25)
	require.NoError(t, err)
	require.True(t, allowed, "nil cache 应 fail-open")
}
// TestCalculateMaxWait: the max wait is concurrency (floored at 1) plus 20.
func TestCalculateMaxWait(t *testing.T) {
	cases := []struct {
		concurrency int
		expected    int
	}{
		{5, 25},  // 5 + 20
		{1, 21},  // 1 + 20
		{0, 21},  // min(1) + 20
		{-1, 21}, // min(1) + 20
		{10, 30}, // 10 + 20
	}
	for _, tc := range cases {
		got := CalculateMaxWait(tc.concurrency)
		require.Equal(t, tc.expected, got, "CalculateMaxWait(%d)", tc.concurrency)
	}
}
func TestGetAccountWaitingCount(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitCount: 5}
svc := NewConcurrencyService(cache)
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, 5, count)
}
// TestGetAccountWaitingCount_NilCache: without a cache the count is zero.
func TestGetAccountWaitingCount_NilCache(t *testing.T) {
	service := &ConcurrencyService{cache: nil}
	count, err := service.GetAccountWaitingCount(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, 0, count)
}
func TestGetAccountConcurrencyBatch(t *testing.T) {
cache := &stubConcurrencyCacheForTest{concurrency: 3}
svc := NewConcurrencyService(cache)
result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3})
require.NoError(t, err)
require.Len(t, result, 3)
for _, id := range []int64{1, 2, 3} {
require.Equal(t, 3, result[id])
}
}
// TestIncrementAccountWaitCount_FailOpen: Redis errors fail open for account
// wait counting as well.
func TestIncrementAccountWaitCount_FailOpen(t *testing.T) {
	stub := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")}
	service := NewConcurrencyService(stub)
	allowed, err := service.IncrementAccountWaitCount(context.Background(), 1, 10)
	require.NoError(t, err, "Redis 错误不应传播")
	require.True(t, allowed, "Redis 错误时应 fail-open")
}
// TestIncrementAccountWaitCount_NilCache: a nil cache allows the request.
func TestIncrementAccountWaitCount_NilCache(t *testing.T) {
	service := &ConcurrencyService{cache: nil}
	allowed, err := service.IncrementAccountWaitCount(context.Background(), 1, 10)
	require.NoError(t, err)
	require.True(t, allowed)
}
......@@ -3,11 +3,12 @@ package service
import (
"context"
"errors"
"log"
"log/slog"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
......@@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() {
return
}
if !s.cfg.Enabled {
log.Printf("[DashboardAggregation] 聚合作业已禁用")
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用")
return
}
......@@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() {
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
s.runScheduledAggregation()
})
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
if !s.cfg.BackfillEnabled {
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
}
}
......@@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return errors.New("聚合服务未初始化")
}
if !s.cfg.BackfillEnabled {
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
return ErrDashboardBackfillDisabled
}
if !end.After(start) {
......@@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, end); err != nil {
log.Printf("[DashboardAggregation] 回填失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err)
}
}()
return nil
......@@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
......@@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, now); err != nil {
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err)
return
}
}
......@@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start,
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
......@@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
now := time.Now().UTC()
last, err := s.repo.GetAggregationWatermark(ctx)
if err != nil {
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err)
last = time.Unix(0, 0).UTC()
}
......@@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
}
if err := s.aggregateRange(ctx, start, now); err != nil {
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err)
return
}
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
if updateErr != nil {
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
}
log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
start.Format(time.RFC3339),
now.Format(time.RFC3339),
time.Since(jobStart).String(),
updateErr == nil,
slog.Debug("[DashboardAggregation] 聚合完成",
"start", start.Format(time.RFC3339),
"end", now.Format(time.RFC3339),
"duration", time.Since(jobStart).String(),
"watermark_updated", updateErr == nil,
)
s.maybeCleanupRetention(ctx, now)
......@@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start,
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
if updateErr != nil {
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
}
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
startUTC.Format(time.RFC3339),
endUTC.Format(time.RFC3339),
time.Since(jobStart).String(),
......@@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start,
return nil
}
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err)
}
return s.repo.AggregateRange(ctx, start, end)
}
......@@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
}
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
if usageErr != nil {
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
s.lastRetentionCleanup.Store(now)
......
......@@ -5,11 +5,11 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
......@@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return cached, nil
}
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err)
}
}
......@@ -188,7 +188,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() {
stats, err := s.fetchDashboardStats(ctx)
if err != nil {
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
return
}
s.applyAggregationStatus(ctx, stats)
......@@ -220,12 +220,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u
}
data, err := json.Marshal(entry)
if err != nil {
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err)
return
}
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err)
}
}
......@@ -237,10 +237,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) {
defer cancel()
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err)
}
if reason != nil {
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
}
}
......@@ -271,7 +271,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T
}
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
if err != nil {
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err)
return time.Unix(0, 0).UTC()
}
if updatedAt.IsZero() {
......@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
......
......@@ -24,6 +24,7 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
PlatformSora = domain.PlatformSora
)
// Account type constants
......@@ -160,6 +161,9 @@ const (
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
// =========================
// Stream Timeout Handling
// =========================
......
......@@ -3,9 +3,10 @@ package service
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// Task type constants
......@@ -56,7 +57,7 @@ func (s *EmailQueueService) start() {
s.wg.Add(1)
go s.worker(i)
}
log.Printf("[EmailQueue] Started %d workers", s.workers)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers)
}
// worker 工作协程
......@@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) {
case task := <-s.taskChan:
s.processTask(id, task)
case <-s.stopChan:
log.Printf("[EmailQueue] Worker %d stopping", id)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id)
return
}
}
......@@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
switch task.TaskType {
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
}
......@@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
......@@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
......@@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
// Stop signals every worker to exit and blocks until all of them have drained.
func (s *EmailQueueService) Stop() {
	close(s.stopChan)
	s.wg.Wait()
	// Pass the message directly as the format string (it contains no % verbs),
	// matching every other LegacyPrintf call site in this service instead of
	// wrapping it in an extra "%s".
	logger.LegacyPrintf("service.email_queue", "[EmailQueue] All workers stopped")
}
......@@ -76,7 +76,7 @@ func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
}
account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
_, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil)
require.Error(t, err)
assert.Equal(t, http.StatusBadGateway, rec.Code)
......@@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
}
account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
_, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil)
require.Error(t, err)
assert.Equal(t, http.StatusTeapot, rec.Code)
......
......@@ -2,13 +2,13 @@ package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
......@@ -72,9 +72,9 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
}
}
......@@ -82,7 +82,7 @@ func NewErrorPassthroughService(
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
......@@ -192,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
......@@ -225,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
......@@ -288,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 先失效缓存,避免后续刷新读到陈旧规则。
if s.cache != nil {
if err := s.cache.Invalidate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err)
}
}
// 刷新本地缓存
if err := s.reloadRulesFromDB(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s.clearLocalCache()
}
......@@ -302,7 +302,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
......
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- helpers ---
// testTimePtr returns a pointer to t, for populating *time.Time fields in test fixtures.
func testTimePtr(t time.Time) *time.Time { return &t }
// makeAccWithLoad builds an accountWithLoad fixture with the given identity,
// priority, load rate, last-used timestamp and account type. The account is
// always schedulable/active and starts with zero current concurrency.
func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad {
	acc := &Account{
		ID:          id,
		Priority:    priority,
		LastUsedAt:  lastUsed,
		Type:        accType,
		Schedulable: true,
		Status:      StatusActive,
	}
	info := &AccountLoadInfo{
		AccountID:          id,
		CurrentConcurrency: 0,
		LoadRate:           loadRate,
	}
	return accountWithLoad{account: acc, loadInfo: info}
}
// --- sortAccountsByPriorityAndLastUsed ---
// TestSortAccountsByPriorityAndLastUsed_ByPriority: lower priority values sort first.
func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) {
	now := time.Now()
	accs := []*Account{
		{ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)},
		{ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)},
		{ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)},
	}
	sortAccountsByPriorityAndLastUsed(accs, false)
	require.Equal(t, int64(2), accs[0].ID, "优先级最低的排第一")
	require.Equal(t, int64(3), accs[1].ID)
	require.Equal(t, int64(1), accs[2].ID)
}
// TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed: within equal
// priority, nil LastUsedAt sorts first, then earlier timestamps.
func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) {
	now := time.Now()
	accs := []*Account{
		{ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)},
		{ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))},
		{ID: 3, Priority: 1, LastUsedAt: nil},
	}
	sortAccountsByPriorityAndLastUsed(accs, false)
	require.Equal(t, int64(3), accs[0].ID, "nil LastUsedAt 排最前")
	require.Equal(t, int64(2), accs[1].ID, "更早使用的排前面")
	require.Equal(t, int64(1), accs[2].ID)
}
// TestSortAccountsByPriorityAndLastUsed_PreferOAuth: with preferOAuth set,
// OAuth accounts win ties over API-key accounts.
func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) {
	accs := []*Account{
		{ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
		{ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth},
	}
	sortAccountsByPriorityAndLastUsed(accs, true)
	require.Equal(t, int64(2), accs[0].ID, "preferOAuth 时 OAuth 账号排前面")
}
// The sorter shuffles within equal (Priority, LastUsedAt) groups, so "stable
// sort" cannot be asserted. Instead verify that (1) the element set is
// preserved on every run and (2) repeated runs produce at least two distinct
// leading elements, proving the shuffle is effective.
func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) {
	accs := []*Account{
		{ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
		{ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
		{ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey},
	}
	firstSeen := make(map[int64]bool)
	for run := 0; run < 100; run++ {
		shuffled := append([]*Account(nil), accs...)
		sortAccountsByPriorityAndLastUsed(shuffled, false)
		firstSeen[shuffled[0].ID] = true
		present := make(map[int64]bool, len(shuffled))
		for _, acc := range shuffled {
			present[acc.ID] = true
		}
		require.True(t, present[1] && present[2] && present[3])
	}
	require.GreaterOrEqual(t, len(firstSeen), 2, "同组账号应能被随机打散")
}
// Mixed priorities and timestamps: priority wins first; within a priority
// group, nil LastUsedAt precedes any concrete timestamp, and earlier
// timestamps precede later ones.
func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) {
	base := time.Now()
	accs := []*Account{
		{ID: 1, Priority: 2, LastUsedAt: nil},
		{ID: 2, Priority: 1, LastUsedAt: testTimePtr(base)},
		{ID: 3, Priority: 1, LastUsedAt: testTimePtr(base.Add(-1 * time.Hour))},
		{ID: 4, Priority: 2, LastUsedAt: testTimePtr(base.Add(-2 * time.Hour))},
	}
	sortAccountsByPriorityAndLastUsed(accs, false)
	require.Equal(t, int64(3), accs[0].ID, "优先级1 + 更早")
	require.Equal(t, int64(2), accs[1].ID, "优先级1 + 现在")
	require.Equal(t, int64(1), accs[2].ID, "优先级2 + nil")
	require.Equal(t, int64(4), accs[3].ID, "优先级2 + 有时间")
}
// --- filterByMinPriority ---
func TestFilterByMinPriority_Empty(t *testing.T) {
result := filterByMinPriority(nil)
require.Nil(t, result)
}
// Only accounts sharing the minimal priority survive filtering, in their
// original relative order.
func TestFilterByMinPriority_SelectsMinPriority(t *testing.T) {
	input := []accountWithLoad{
		makeAccWithLoad(1, 5, 10, nil, AccountTypeAPIKey),
		makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey),
		makeAccWithLoad(3, 1, 20, nil, AccountTypeAPIKey),
		makeAccWithLoad(4, 2, 10, nil, AccountTypeAPIKey),
	}
	filtered := filterByMinPriority(input)
	require.Len(t, filtered, 2)
	require.Equal(t, int64(2), filtered[0].account.ID)
	require.Equal(t, int64(3), filtered[1].account.ID)
}
// --- filterByMinLoadRate ---
func TestFilterByMinLoadRate_Empty(t *testing.T) {
result := filterByMinLoadRate(nil)
require.Nil(t, result)
}
// Only accounts at the minimal load rate survive, preserving input order.
func TestFilterByMinLoadRate_SelectsMinLoadRate(t *testing.T) {
	input := []accountWithLoad{
		makeAccWithLoad(1, 1, 30, nil, AccountTypeAPIKey),
		makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey),
		makeAccWithLoad(3, 1, 10, nil, AccountTypeAPIKey),
		makeAccWithLoad(4, 1, 20, nil, AccountTypeAPIKey),
	}
	filtered := filterByMinLoadRate(input)
	require.Len(t, filtered, 2)
	require.Equal(t, int64(2), filtered[0].account.ID)
	require.Equal(t, int64(3), filtered[1].account.ID)
}
// --- selectByLRU ---
func TestSelectByLRU_Empty(t *testing.T) {
result := selectByLRU(nil, false)
require.Nil(t, result)
}
// A single candidate is always selected.
func TestSelectByLRU_Single(t *testing.T) {
	only := []accountWithLoad{makeAccWithLoad(1, 1, 10, nil, AccountTypeAPIKey)}
	picked := selectByLRU(only, false)
	require.NotNil(t, picked)
	require.Equal(t, int64(1), picked.account.ID)
}
// A nil LastUsedAt (never used) beats any concrete timestamp.
func TestSelectByLRU_NilLastUsedAtWins(t *testing.T) {
	base := time.Now()
	candidates := []accountWithLoad{
		makeAccWithLoad(1, 1, 10, testTimePtr(base), AccountTypeAPIKey),
		makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey),
		makeAccWithLoad(3, 1, 10, testTimePtr(base.Add(-1*time.Hour)), AccountTypeAPIKey),
	}
	picked := selectByLRU(candidates, false)
	require.NotNil(t, picked)
	require.Equal(t, int64(2), picked.account.ID)
}
// Among concrete timestamps, the earliest LastUsedAt wins.
func TestSelectByLRU_EarliestTimeWins(t *testing.T) {
	base := time.Now()
	candidates := []accountWithLoad{
		makeAccWithLoad(1, 1, 10, testTimePtr(base), AccountTypeAPIKey),
		makeAccWithLoad(2, 1, 10, testTimePtr(base.Add(-1*time.Hour)), AccountTypeAPIKey),
		makeAccWithLoad(3, 1, 10, testTimePtr(base.Add(-2*time.Hour)), AccountTypeAPIKey),
	}
	picked := selectByLRU(candidates, false)
	require.NotNil(t, picked)
	require.Equal(t, int64(3), picked.account.ID)
}
// When the minimal LastUsedAt is shared by an API-key and an OAuth account,
// preferOAuth must deterministically pick the OAuth one — verified over many
// iterations to guard against any randomized tie-breaking.
func TestSelectByLRU_TiePreferOAuth(t *testing.T) {
	base := time.Now()
	candidates := []accountWithLoad{
		makeAccWithLoad(1, 1, 10, testTimePtr(base), AccountTypeAPIKey),
		makeAccWithLoad(2, 1, 10, testTimePtr(base), AccountTypeOAuth),
		makeAccWithLoad(3, 1, 10, testTimePtr(base.Add(1*time.Hour)), AccountTypeAPIKey),
	}
	for attempt := 0; attempt < 50; attempt++ {
		picked := selectByLRU(candidates, true)
		require.NotNil(t, picked)
		require.Equal(t, AccountTypeOAuth, picked.account.Type)
		require.Equal(t, int64(2), picked.account.ID)
	}
}
package service
import "testing"
// BenchmarkGatewayService_ParseSSEUsage_MessageStart measures parsing usage
// from a message_start SSE payload.
func BenchmarkGatewayService_ParseSSEUsage_MessageStart(b *testing.B) {
	gw := &GatewayService{}
	payload := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}`
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		u := &ClaudeUsage{}
		gw.parseSSEUsage(payload, u)
	}
}
// BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart measures the
// passthrough-mode usage parser on a message_start SSE payload.
func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart(b *testing.B) {
	gw := &GatewayService{}
	payload := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}`
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		u := &ClaudeUsage{}
		gw.parseSSEUsagePassthrough(payload, u)
	}
}
// BenchmarkGatewayService_ParseSSEUsage_MessageDelta measures parsing usage
// from a message_delta SSE payload.
func BenchmarkGatewayService_ParseSSEUsage_MessageDelta(b *testing.B) {
	gw := &GatewayService{}
	payload := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}`
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		u := &ClaudeUsage{}
		gw.parseSSEUsage(payload, u)
	}
}
// BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta measures the
// passthrough-mode usage parser on a message_delta SSE payload.
func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta(b *testing.B) {
	gw := &GatewayService{}
	payload := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}`
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		u := &ClaudeUsage{}
		gw.parseSSEUsagePassthrough(payload, u)
	}
}
// BenchmarkParseClaudeUsageFromResponseBody measures non-streaming usage
// extraction from a complete response body.
func BenchmarkParseClaudeUsageFromResponseBody(b *testing.B) {
	payload := []byte(`{"id":"msg_123","type":"message","usage":{"input_tokens":123,"output_tokens":456,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = parseClaudeUsageFromResponseBody(payload)
	}
}
package service
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// anthropicHTTPUpstreamRecorder is a test double for the HTTP upstream
// client. It records the last request (and a copy of its body) and returns
// a preconfigured response or error.
type anthropicHTTPUpstreamRecorder struct {
	lastReq  *http.Request  // most recent request passed to Do
	lastBody []byte         // body of lastReq, read and then restored
	resp     *http.Response // canned response returned on success
	err      error          // when non-nil, Do fails with this error
}
// newAnthropicAPIKeyAccountForTest builds an active, schedulable Anthropic
// API-key account with passthrough enabled, for the passthrough test suite.
func newAnthropicAPIKeyAccountForTest() *Account {
	creds := map[string]any{
		"api_key":  "upstream-anthropic-key",
		"base_url": "https://api.anthropic.com",
	}
	extra := map[string]any{
		"anthropic_passthrough": true,
	}
	return &Account{
		ID:          201,
		Name:        "anthropic-apikey-pass-test",
		Platform:    PlatformAnthropic,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: creds,
		Extra:       extra,
		Status:      StatusActive,
		Schedulable: true,
	}
}
// Do records the request and a copy of its body (restoring the body so the
// caller can still read it), then returns the canned response or error.
func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	u.lastReq = req
	if req != nil && req.Body != nil {
		data, _ := io.ReadAll(req.Body)
		_ = req.Body.Close()
		u.lastBody = data
		req.Body = io.NopCloser(bytes.NewReader(data))
	}
	if u.err != nil {
		return nil, u.err
	}
	return u.resp, nil
}
// DoWithTLS delegates to Do; the TLS-fingerprint flag is ignored by the stub.
func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// streamReadCloser is an io.ReadCloser stub that yields its payload in a
// single Read and then returns the configured error (or io.EOF) on
// subsequent reads.
type streamReadCloser struct {
	payload []byte // bytes returned by the first Read call
	sent    bool   // set once the payload has been delivered
	err     error  // error returned after the payload (io.EOF when nil)
}
// Read delivers the payload exactly once; later calls return the configured
// error, or io.EOF when none was set.
func (r *streamReadCloser) Read(p []byte) (int, error) {
	if r.sent {
		if r.err != nil {
			return 0, r.err
		}
		return 0, io.EOF
	}
	r.sent = true
	return copy(p, r.payload), nil
}
// Close implements io.Closer and is a no-op for the in-memory stub.
func (r *streamReadCloser) Close() error { return nil }
// failWriteResponseWriter simulates a disconnected client: every write to
// the response fails immediately.
type failWriteResponseWriter struct {
	gin.ResponseWriter
}
// Write always fails, simulating a client that disconnected mid-stream.
func (w *failWriteResponseWriter) Write(data []byte) (int, error) {
	return 0, errors.New("client disconnected")
}
// WriteString always fails, simulating a client that disconnected mid-stream.
func (w *failWriteResponseWriter) WriteString(_ string) (int, error) {
	return 0, errors.New("client disconnected")
}
// Streaming passthrough via Forward(): the upstream request body must be
// forwarded byte-for-byte (model_mapping ignored), inbound credentials and
// cookies stripped and replaced by the account's API key, the SSE stream
// relayed unmodified while usage is still parsed for billing, unsafe
// response headers (Set-Cookie) filtered, and the raw upstream request body
// cached on the gin context as []byte.
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAndAuthReplacement(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	c.Request.Header.Set("User-Agent", "claude-cli/1.0.0")
	c.Request.Header.Set("Authorization", "Bearer inbound-token")
	c.Request.Header.Set("X-Api-Key", "inbound-api-key")
	c.Request.Header.Set("X-Goog-Api-Key", "inbound-goog-key")
	c.Request.Header.Set("Cookie", "secret=1")
	c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
	body := []byte(`{"model":"claude-3-7-sonnet-20250219","stream":true,"system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
	parsed := &ParsedRequest{
		Body:   body,
		Model:  "claude-3-7-sonnet-20250219",
		Stream: true,
	}
	upstreamSSE := strings.Join([]string{
		`data: {"type":"message_start","message":{"usage":{"input_tokens":9,"cached_tokens":7}}}`,
		"",
		`data: {"type":"message_delta","usage":{"output_tokens":3}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	upstream := &anthropicHTTPUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
				"x-request-id": []string{"rid-anthropic-pass"},
				"Set-Cookie":   []string{"secret=upstream"},
			},
			Body: io.NopCloser(strings.NewReader(upstreamSSE)),
		},
	}
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
		httpUpstream:        upstream,
		rateLimitService:    &RateLimitService{},
		deferredService:     &DeferredService{},
		billingCacheService: nil,
	}
	// model_mapping is deliberately configured to prove passthrough does not
	// apply it.
	account := &Account{
		ID:          101,
		Name:        "anthropic-apikey-pass",
		Platform:    PlatformAnthropic,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":       "upstream-anthropic-key",
			"base_url":      "https://api.anthropic.com",
			"model_mapping": map[string]any{"claude-3-7-sonnet-20250219": "claude-3-haiku-20240307"},
		},
		Extra: map[string]any{
			"anthropic_passthrough": true,
		},
		Status:      StatusActive,
		Schedulable: true,
	}
	result, err := svc.Forward(context.Background(), c, account, parsed)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.Stream)
	require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
	require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
	require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
	require.Empty(t, upstream.lastReq.Header.Get("authorization"))
	require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key"))
	require.Empty(t, upstream.lastReq.Header.Get("cookie"))
	require.Equal(t, "2023-06-01", upstream.lastReq.Header.Get("anthropic-version"))
	require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta"))
	require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头")
	require.Contains(t, rec.Body.String(), `"cached_tokens":7`)
	require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写")
	require.Equal(t, 7, result.Usage.CacheReadInputTokens, "计费 usage 解析应保留 cached_tokens 兼容")
	require.Empty(t, rec.Header().Get("Set-Cookie"), "响应头应经过安全过滤")
	rawBody, ok := c.Get(OpsUpstreamRequestBodyKey)
	require.True(t, ok)
	bodyBytes, ok := rawBody.([]byte)
	require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
	require.Equal(t, body, bodyBytes)
}
// Non-streaming count_tokens passthrough: the request body is forwarded
// verbatim (model_mapping ignored), inbound credentials stripped and
// replaced, the upstream JSON relayed as-is, and unsafe response headers
// filtered.
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
	c.Request.Header.Set("Authorization", "Bearer inbound-token")
	c.Request.Header.Set("X-Api-Key", "inbound-api-key")
	c.Request.Header.Set("Cookie", "secret=1")
	body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"thinking":{"type":"enabled"}}`)
	parsed := &ParsedRequest{
		Body:  body,
		Model: "claude-3-5-sonnet-latest",
	}
	upstreamRespBody := `{"input_tokens":42}`
	upstream := &anthropicHTTPUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"application/json"},
				"x-request-id": []string{"rid-count"},
				"Set-Cookie":   []string{"secret=upstream"},
			},
			Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
		},
	}
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
		httpUpstream:     upstream,
		rateLimitService: &RateLimitService{},
	}
	account := &Account{
		ID:          102,
		Name:        "anthropic-apikey-pass-count",
		Platform:    PlatformAnthropic,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":       "upstream-anthropic-key",
			"base_url":      "https://api.anthropic.com",
			"model_mapping": map[string]any{"claude-3-5-sonnet-latest": "claude-3-opus-20240229"},
		},
		Extra: map[string]any{
			"anthropic_passthrough": true,
		},
		Status:      StatusActive,
		Schedulable: true,
	}
	err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
	require.NoError(t, err)
	require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
	require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
	require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
	require.Empty(t, upstream.lastReq.Header.Get("authorization"))
	require.Empty(t, upstream.lastReq.Header.Get("cookie"))
	require.Equal(t, http.StatusOK, rec.Code)
	require.JSONEq(t, upstreamRespBody, rec.Body.String())
	require.Empty(t, rec.Header().Get("Set-Cookie"))
}
// buildUpstreamRequestAnthropicAPIKeyPassthrough must reject a base_url that
// cannot be parsed, even when the URL allowlist is disabled.
func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Security: config.SecurityConfig{
				URLAllowlist: config.URLAllowlistConfig{
					Enabled: false,
				},
			},
		},
	}
	account := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key":  "k",
			"base_url": "://invalid-url", // unparsable URL
		},
	}
	_, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "k")
	require.Error(t, err)
}
// The API-key passthrough toggle must not affect OAuth accounts: the account
// helper reports passthrough disabled and the OAuth upstream request is
// still built the original way (bearer token plus the oauth beta header).
func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
		},
	}
	// anthropic_passthrough is set but the account type is OAuth, so the
	// toggle must be ignored.
	account := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"anthropic_passthrough": true,
		},
	}
	require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
	req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false)
	require.NoError(t, err)
	require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization"))
	require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
}
// Even when the client has already disconnected (canceled request context),
// the streaming handler must keep draining the upstream so usage is fully
// collected for billing.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Use a canceled context recorder to simulate client disconnect behavior.
	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	ctx, cancel := context.WithCancel(req.Context())
	cancel()
	req = req.WithContext(ctx)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
		rateLimitService: &RateLimitService{},
	}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
			"",
			`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
			"",
			"data: [DONE]",
			"",
		}, "\n"))),
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.usage)
	require.Equal(t, 11, result.usage.InputTokens)
	require.Equal(t, 5, result.usage.OutputTokens)
}
// Non-streaming passthrough success: the upstream JSON is relayed verbatim
// to the client while usage is parsed from the response body, exercising the
// cached_tokens fallback (cache read = 4) and the 5m+1h backfill
// (cache creation = 2 + 3 = 5).
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
	upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":12,"output_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":3},"cached_tokens":4}}`
	upstream := &anthropicHTTPUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"application/json"},
				"x-request-id": []string{"rid-nonstream"},
			},
			Body: io.NopCloser(strings.NewReader(upstreamJSON)),
		},
	}
	svc := &GatewayService{
		cfg:              &config.Config{},
		httpUpstream:     upstream,
		rateLimitService: &RateLimitService{},
	}
	result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now())
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 12, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 5, result.Usage.CacheCreationInputTokens)
	require.Equal(t, 4, result.Usage.CacheReadInputTokens)
	require.Equal(t, upstreamJSON, rec.Body.String())
}
// Passthrough requires an API-key token; an OAuth account must be rejected
// with an explicit "requires apikey token" error.
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	account := &Account{
		ID:       202,
		Name:     "anthropic-oauth",
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "oauth-token",
		},
	}
	svc := &GatewayService{}
	result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now())
	require.Nil(t, result)
	require.Error(t, err)
	require.Contains(t, err.Error(), "requires apikey token")
}
// An upstream transport failure maps to an "upstream request failed" error,
// writes 502 Bad Gateway to the client, and still caches the upstream
// request body on the context for ops inspection.
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequestError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	upstream := &anthropicHTTPUpstreamRecorder{
		err: errors.New("dial tcp timeout"),
	}
	svc := &GatewayService{
		cfg: &config.Config{
			Security: config.SecurityConfig{
				URLAllowlist: config.URLAllowlistConfig{Enabled: false},
			},
		},
		httpUpstream: upstream,
	}
	account := newAnthropicAPIKeyAccountForTest()
	result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now())
	require.Nil(t, result)
	require.Error(t, err)
	require.Contains(t, err.Error(), "upstream request failed")
	require.Equal(t, http.StatusBadGateway, rec.Code)
	rawBody, ok := c.Get(OpsUpstreamRequestBodyKey)
	require.True(t, ok)
	_, ok = rawBody.([]byte)
	require.True(t, ok)
}
// A nil upstream response body must be rejected with an "empty response"
// error rather than panicking or returning a bogus result.
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	upstream := &anthropicHTTPUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"x-request-id": []string{"rid-empty-body"}},
			Body:       nil,
		},
	}
	svc := &GatewayService{
		cfg: &config.Config{
			Security: config.SecurityConfig{
				URLAllowlist: config.URLAllowlistConfig{Enabled: false},
			},
		},
		httpUpstream: upstream,
	}
	result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now())
	require.Nil(t, result)
	require.Error(t, err)
	require.Contains(t, err.Error(), "empty response")
}
// extractAnthropicSSEDataLine must accept "data: ..." lines (trimming the
// prefix and spacing) and reject any other SSE line.
func TestExtractAnthropicSSEDataLine(t *testing.T) {
	t.Run("valid data line with spaces", func(t *testing.T) {
		payload, ok := extractAnthropicSSEDataLine("data: {\"type\":\"message_start\"}")
		require.True(t, ok)
		require.Equal(t, `{"type":"message_start"}`, payload)
	})
	t.Run("non data line", func(t *testing.T) {
		payload, ok := extractAnthropicSSEDataLine("event: message_start")
		require.False(t, ok)
		require.Empty(t, payload)
	})
}
// message_start parsing: cached_tokens feeds the cache-read fallback and the
// 5m/1h detail backfills an absent aggregate cache-creation value.
func TestGatewayService_ParseSSEUsagePassthrough_MessageStartFallbacks(t *testing.T) {
	gw := &GatewayService{}
	u := &ClaudeUsage{}
	payload := `{"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":9,"cache_creation":{"ephemeral_5m_input_tokens":3,"ephemeral_1h_input_tokens":4}}}}`
	gw.parseSSEUsagePassthrough(payload, u)
	require.Equal(t, 12, u.InputTokens)
	require.Equal(t, 9, u.CacheReadInputTokens, "应兼容 cached_tokens 字段")
	require.Equal(t, 7, u.CacheCreationInputTokens, "聚合字段为空时应从 5m/1h 明细回填")
	require.Equal(t, 3, u.CacheCreation5mTokens)
	require.Equal(t, 4, u.CacheCreation1hTokens)
}
// message_delta merging: zero values must not clobber previously collected
// input tokens or 1h cache detail, while explicit non-zero values overwrite
// and cached_tokens backfills an empty cache_read_input_tokens.
func TestGatewayService_ParseSSEUsagePassthrough_MessageDeltaSelectiveOverwrite(t *testing.T) {
	gw := &GatewayService{}
	u := &ClaudeUsage{
		InputTokens:           10,
		CacheCreation5mTokens: 2,
		CacheCreation1hTokens: 6,
	}
	payload := `{"type":"message_delta","usage":{"input_tokens":0,"output_tokens":5,"cache_creation_input_tokens":8,"cache_read_input_tokens":0,"cached_tokens":11,"cache_creation":{"ephemeral_5m_input_tokens":1,"ephemeral_1h_input_tokens":0}}}`
	gw.parseSSEUsagePassthrough(payload, u)
	require.Equal(t, 10, u.InputTokens, "message_delta 中 0 值不应覆盖已有 input_tokens")
	require.Equal(t, 5, u.OutputTokens)
	require.Equal(t, 8, u.CacheCreationInputTokens)
	require.Equal(t, 11, u.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退到 cached_tokens")
	require.Equal(t, 1, u.CacheCreation5mTokens)
	require.Equal(t, 6, u.CacheCreation1hTokens, "message_delta 中 0 值不应覆盖已有 1h 明细")
}
// Empty, sentinel and malformed payloads must leave usage untouched, and a
// nil usage target must not panic.
func TestGatewayService_ParseSSEUsagePassthrough_NoopCases(t *testing.T) {
	gw := &GatewayService{}
	u := &ClaudeUsage{InputTokens: 3}
	for _, payload := range []string{"", "[DONE]", "not-json"} {
		gw.parseSSEUsagePassthrough(payload, u)
		require.Equal(t, 3, u.InputTokens)
	}
	// A nil usage target must not panic.
	gw.parseSSEUsagePassthrough(`{"type":"message_start"}`, nil)
}
// Usage fallbacks also apply when reading a generic "usage" node on other
// event types: cached_tokens → cache read, 5m+1h detail → cache creation.
func TestGatewayService_ParseSSEUsagePassthrough_FallbackFromUsageNode(t *testing.T) {
	gw := &GatewayService{}
	u := &ClaudeUsage{}
	payload := `{"type":"content_block_delta","usage":{"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":1}}}`
	gw.parseSSEUsagePassthrough(payload, u)
	require.Equal(t, 6, u.CacheReadInputTokens)
	require.Equal(t, 3, u.CacheCreationInputTokens)
}
// parseClaudeUsageFromResponseBody behaviors: nil/absent usage yields a
// zero-value struct (never nil), fallbacks fill missing aggregates from
// cached_tokens and the 5m/1h detail, and explicitly provided aggregates are
// never overwritten by the detail fields.
func TestParseClaudeUsageFromResponseBody(t *testing.T) {
	t.Run("empty or missing usage", func(t *testing.T) {
		got := parseClaudeUsageFromResponseBody(nil)
		require.NotNil(t, got)
		require.Equal(t, 0, got.InputTokens)
		got = parseClaudeUsageFromResponseBody([]byte(`{"id":"x"}`))
		require.NotNil(t, got)
		require.Equal(t, 0, got.OutputTokens)
	})
	t.Run("parse all usage fields and fallback", func(t *testing.T) {
		body := []byte(`{"usage":{"input_tokens":21,"output_tokens":34,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":13,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":8}}}`)
		got := parseClaudeUsageFromResponseBody(body)
		require.Equal(t, 21, got.InputTokens)
		require.Equal(t, 34, got.OutputTokens)
		require.Equal(t, 13, got.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退 cached_tokens")
		require.Equal(t, 13, got.CacheCreationInputTokens, "聚合字段为空时应由 5m/1h 回填")
		require.Equal(t, 5, got.CacheCreation5mTokens)
		require.Equal(t, 8, got.CacheCreation1hTokens)
	})
	t.Run("keep explicit aggregate values", func(t *testing.T) {
		body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"cache_creation_input_tokens":9,"cache_read_input_tokens":7,"cached_tokens":99,"cache_creation":{"ephemeral_5m_input_tokens":4,"ephemeral_1h_input_tokens":5}}}`)
		got := parseClaudeUsageFromResponseBody(body)
		require.Equal(t, 9, got.CacheCreationInputTokens, "已显式提供聚合字段时不应被明细覆盖")
		require.Equal(t, 7, got.CacheReadInputTokens, "已显式提供 cache_read_input_tokens 时不应回退 cached_tokens")
	})
}
// A single SSE line longer than the scanner's maximum buffer must surface
// bufio.ErrTooLong while still returning a non-nil result for accounting.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: 32,
			},
		},
	}
	// The scanner's initial buffer is 64KB; build a longer single line to
	// trigger bufio.ErrTooLong.
	longLine := "data: " + strings.Repeat("x", 80*1024)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       io.NopCloser(strings.NewReader(longLine)),
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219")
	require.Error(t, err)
	require.ErrorIs(t, err, bufio.ErrTooLong)
	require.NotNil(t, result)
}
// An upstream that stays silent longer than StreamDataIntervalTimeout
// (seconds) must abort with a "stream data interval timeout" error; this is
// an upstream fault, not a client disconnect.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 1,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
		rateLimitService: &RateLimitService{},
	}
	// A pipe with no writer activity simulates a silent upstream.
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       pr,
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219")
	_ = pw.Close()
	_ = pr.Close()
	require.Error(t, err)
	require.Contains(t, err.Error(), "stream data interval timeout")
	require.NotNil(t, result)
	require.False(t, result.clientDisconnect)
}
// A mid-stream upstream read failure surfaces as a "stream read error" and
// must not be misclassified as a client disconnect.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
	}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: &streamReadCloser{
			err: io.ErrUnexpectedEOF,
		},
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219")
	require.Error(t, err)
	require.Contains(t, err.Error(), "stream read error")
	require.NotNil(t, result)
	require.False(t, result.clientDisconnect)
}
// After the client disconnects (response writes fail), a subsequent upstream
// silence timeout is treated as a clean client-disconnect result: no error,
// and the usage collected so far is preserved for billing.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer}
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 1,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
		rateLimitService: &RateLimitService{},
	}
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       pr,
	}
	done := make(chan struct{})
	go func() {
		defer close(done)
		_, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n"))
		// Keep the upstream connection silent to trigger the data-interval
		// timeout branch.
		time.Sleep(1500 * time.Millisecond)
		_ = pw.Close()
	}()
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219")
	_ = pr.Close()
	<-done
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.clientDisconnect)
	require.Equal(t, 9, result.usage.InputTokens)
}
// context.Canceled from the upstream body is interpreted as a client
// disconnect: no error is returned and the result is flagged accordingly.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
	}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: &streamReadCloser{
			err: context.Canceled,
		},
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.clientDisconnect)
}
// When the client has already disconnected (writes fail) and the upstream
// then errors, the handler still returns the usage collected before the
// failure, flagged as a client disconnect rather than an error.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAfterClientDisconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer}
	svc := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				MaxLineSize: defaultMaxLineSize,
			},
		},
	}
	// One successful payload chunk, then a hard read error.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body: &streamReadCloser{
			payload: []byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":8}}}` + "\n\n"),
			err:     io.ErrUnexpectedEOF,
		},
	}
	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.clientDisconnect)
	require.Equal(t, 8, result.usage.InputTokens)
}
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/require"
)
// userGroupRateRepoHotpathStub stubs UserGroupRateRepository for hot-path
// tests: it counts invocations, can block until released, and returns a
// canned rate or error.
type userGroupRateRepoHotpathStub struct {
	UserGroupRateRepository
	rate  *float64        // rate returned on success (may be nil)
	err   error           // when non-nil, GetByUserAndGroup fails with it
	wait  <-chan struct{} // when non-nil, GetByUserAndGroup blocks until it is closed
	calls atomic.Int64    // number of GetByUserAndGroup invocations
}
// GetByUserAndGroup records the call, optionally blocks until the wait
// channel is closed, then returns the configured error or rate.
func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
	s.calls.Add(1)
	if gate := s.wait; gate != nil {
		<-gate
	}
	if loadErr := s.err; loadErr != nil {
		return nil, loadErr
	}
	return s.rate, nil
}
// usageLogWindowBatchRepoStub is a UsageLogRepository stub that serves both
// the batch and the single-account window-stats queries, counting each kind
// of call separately.
type usageLogWindowBatchRepoStub struct {
	UsageLogRepository
	batchResult  map[int64]*usagestats.AccountStats // per-account stats served by the batch query
	batchErr     error                              // when set, the batch query fails
	batchCalls   atomic.Int64                       // number of batch query invocations
	singleResult map[int64]*usagestats.AccountStats // per-account stats served by the single query
	singleErr    error                              // when set, the single query fails
	singleCalls  atomic.Int64                       // number of single query invocations
}
// GetAccountWindowStatsBatch counts the batch call and either fails with
// batchErr or returns the subset of batchResult covering the requested IDs.
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
	s.batchCalls.Add(1)
	if s.batchErr != nil {
		return nil, s.batchErr
	}
	result := make(map[int64]*usagestats.AccountStats, len(accountIDs))
	for _, accountID := range accountIDs {
		stats, found := s.batchResult[accountID]
		if !found {
			continue
		}
		result[accountID] = stats
	}
	return result, nil
}
// GetAccountWindowStats counts the single-account call and either fails with
// singleErr or returns the configured stats (zero-value stats when unset).
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
	s.singleCalls.Add(1)
	if s.singleErr != nil {
		return nil, s.singleErr
	}
	stats, found := s.singleResult[accountID]
	if !found {
		return &usagestats.AccountStats{}, nil
	}
	return stats, nil
}
// sessionLimitCacheHotpathStub is a SessionLimitCache stub: batchData backs
// GetWindowCostBatch reads, and setData records SetWindowCost writes.
type sessionLimitCacheHotpathStub struct {
	SessionLimitCache
	batchData map[int64]float64 // window costs served by GetWindowCostBatch
	batchErr  error             // when set, GetWindowCostBatch fails
	setData   map[int64]float64 // costs recorded by SetWindowCost (lazily allocated)
	setErr    error             // when set, SetWindowCost fails
}
// GetWindowCostBatch returns batchErr when configured, otherwise the subset
// of batchData covering the requested account IDs.
func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
	if s.batchErr != nil {
		return nil, s.batchErr
	}
	costs := make(map[int64]float64, len(accountIDs))
	for _, accountID := range accountIDs {
		cost, found := s.batchData[accountID]
		if !found {
			continue
		}
		costs[accountID] = cost
	}
	return costs, nil
}
// SetWindowCost records the written cost in setData (allocated on first use)
// or fails with setErr when configured.
func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
	if s.setErr != nil {
		return s.setErr
	}
	if s.setData == nil {
		s.setData = map[int64]float64{}
	}
	s.setData[accountID] = cost
	return nil
}
// modelsListAccountRepoStub is an AccountRepository stub for the models-list
// cache tests, serving group-scoped and global account listings with separate
// call counters.
type modelsListAccountRepoStub struct {
	AccountRepository
	byGroup          map[int64][]Account // accounts returned per group ID
	all              []Account           // accounts returned by the global listing
	err              error               // when set, both listing methods fail
	listByGroupCalls atomic.Int64        // ListSchedulableByGroupID invocations
	listAllCalls     atomic.Int64        // ListSchedulable invocations
}
// stickyGatewayCacheHotpathStub is a GatewayCache stub whose sticky-session
// lookup returns stickyID (or an error when it is not positive) and counts
// reads; the write/refresh/delete operations are no-ops.
type stickyGatewayCacheHotpathStub struct {
	GatewayCache
	stickyID int64        // account ID returned by GetSessionAccountID when > 0
	getCalls atomic.Int64 // number of GetSessionAccountID invocations
}
// GetSessionAccountID counts the lookup and returns the configured sticky
// account ID, or an error when none is configured.
func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
	s.getCalls.Add(1)
	if s.stickyID <= 0 {
		return 0, errors.New("not found")
	}
	return s.stickyID, nil
}
// SetSessionAccountID is a no-op success; sticky writes are not asserted on
// in these tests.
func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
	return nil
}
// RefreshSessionTTL is a no-op success.
func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
	return nil
}
// DeleteSessionAccountID is a no-op success.
func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
	return nil
}
// ListSchedulableByGroupID counts the call and returns a defensive copy of
// the accounts configured for the group, or nil when the group is unknown.
func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
	s.listByGroupCalls.Add(1)
	if s.err != nil {
		return nil, s.err
	}
	accounts, found := s.byGroup[groupID]
	if !found {
		return nil, nil
	}
	cloned := make([]Account, len(accounts))
	copy(cloned, accounts)
	return cloned, nil
}
// ListSchedulable counts the call and returns a defensive copy of all
// configured accounts.
func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
	s.listAllCalls.Add(1)
	if s.err != nil {
		return nil, s.err
	}
	cloned := make([]Account, len(s.all))
	copy(cloned, s.all)
	return cloned, nil
}
// resetGatewayHotpathStatsForTest zeroes every package-level hot-path counter
// so each test starts from a clean slate. Tests asserting on these counters
// presumably run serially — confirm none of them use t.Parallel.
func resetGatewayHotpathStatsForTest() {
	windowCostPrefetchCacheHitTotal.Store(0)
	windowCostPrefetchCacheMissTotal.Store(0)
	windowCostPrefetchBatchSQLTotal.Store(0)
	windowCostPrefetchFallbackTotal.Store(0)
	windowCostPrefetchErrorTotal.Store(0)
	userGroupRateCacheHitTotal.Store(0)
	userGroupRateCacheMissTotal.Store(0)
	userGroupRateCacheLoadTotal.Store(0)
	userGroupRateCacheSFSharedTotal.Store(0)
	userGroupRateCacheFallbackTotal.Store(0)
	modelsListCacheHitTotal.Store(0)
	modelsListCacheMissTotal.Store(0)
	modelsListCacheStoreTotal.Store(0)
}
// TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight verifies that
// concurrent cache misses for the same (user, group) pair collapse into a
// single repository load, and that a later read is served from the cache
// without another repository call.
func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	rate := 1.7
	unblock := make(chan struct{})
	// The stub blocks until unblock is closed, so every goroutine piles up on
	// the same in-flight load.
	repo := &userGroupRateRepoHotpathStub{
		rate: &rate,
		wait: unblock,
	}
	svc := &GatewayService{
		userGroupRateRepo:  repo,
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				UserGroupRateCacheTTLSeconds: 30,
			},
		},
	}
	const concurrent = 12
	results := make([]float64, concurrent)
	start := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(concurrent)
	for i := 0; i < concurrent; i++ {
		go func(idx int) {
			defer wg.Done()
			<-start
			results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
		}(i)
	}
	close(start)
	// Give the goroutines a moment to reach the blocked repository call
	// before releasing it.
	time.Sleep(20 * time.Millisecond)
	close(unblock)
	wg.Wait()
	for _, got := range results {
		require.Equal(t, rate, got)
	}
	// All concurrent misses must have been collapsed into a single load.
	require.Equal(t, int64(1), repo.calls.Load())
	// A subsequent read should hit the cache and not go back to the source.
	got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
	require.Equal(t, rate, got)
	require.Equal(t, int64(1), repo.calls.Load())
	hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats()
	require.GreaterOrEqual(t, hit, int64(1))
	// Fix: use the constant rather than a magic 12, so the assertion keeps
	// tracking the goroutine count if it ever changes.
	require.Equal(t, int64(concurrent), miss)
	require.Equal(t, int64(1), load)
	require.GreaterOrEqual(t, sfShared, int64(1))
	require.Equal(t, int64(0), fallback)
}
// TestGetUserGroupRateMultiplier_FallbackOnRepoError verifies that a
// repository failure falls back to the supplied default multiplier and is
// recorded in the fallback statistic.
func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	rateRepo := &userGroupRateRepoHotpathStub{
		err: errors.New("db down"),
	}
	gateway := &GatewayService{
		userGroupRateRepo:  rateRepo,
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				UserGroupRateCacheTTLSeconds: 30,
			},
		},
	}
	multiplier := gateway.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25)
	require.Equal(t, 1.25, multiplier)
	require.Equal(t, int64(1), rateRepo.calls.Load())
	_, _, _, _, fallback := GatewayUserGroupRateCacheStats()
	require.Equal(t, int64(1), fallback)
}
// TestGetUserGroupRateMultiplier_CacheHitAndNilRepo verifies that a warm
// cache entry is served without touching the repository, and that a service
// with no repository falls back to the group default multiplier when the
// cache cannot help.
func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	repo := &userGroupRateRepoHotpathStub{
		err: errors.New("should not be called"),
	}
	svc := &GatewayService{
		userGroupRateRepo:  repo,
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
	}
	key := "101:202"
	svc.userGroupRateCache.Set(key, 2.3, time.Minute)
	got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1)
	require.Equal(t, 2.3, got)
	hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
	require.Equal(t, int64(1), hit)
	require.Equal(t, int64(0), miss)
	require.Equal(t, int64(0), load)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), repo.calls.Load())
	// Without a repo the group default multiplier is returned directly
	// (a cache hit still wins when present).
	svc2 := &GatewayService{
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
	}
	svc2.userGroupRateCache.Set(key, 1.9, time.Minute)
	require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
	require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4))
	svc2.userGroupRateCache.Delete(key)
	require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
}
// TestWithWindowCostPrefetch_BatchReadAndContextReuse verifies the prefetch
// pipeline: cached window costs are read in one batch, misses are filled by a
// single batch SQL query whose result is written back to the cache, and the
// account without a session window is left out of the prefetch map.
func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		// Account 1: window cost available in the cache.
		{
			ID:                 1,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeOAuth,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		// Account 2: cache miss, served by the batch SQL query.
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		// Account 3: API-key account without a session window; the assertions
		// below show it never enters the prefetch map.
		{
			ID:       3,
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra:    map[string]any{"window_cost_limit": 100.0},
		},
	}
	cache := &sessionLimitCacheHotpathStub{
		batchData: map[int64]float64{
			1: 11.0,
		},
	}
	repo := &usageLogWindowBatchRepoStub{
		batchResult: map[int64]*usagestats.AccountStats{
			2: {StandardCost: 22.0},
		},
	}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}
	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	require.NotNil(t, outCtx)
	cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
	require.True(t, ok1)
	require.Equal(t, 11.0, cost1)
	cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok2)
	require.Equal(t, 22.0, cost2)
	_, ok3 := windowCostFromPrefetchContext(outCtx, 3)
	require.False(t, ok3)
	// Exactly one batch SQL query, and the fetched cost is written back to the cache.
	require.Equal(t, int64(1), repo.batchCalls.Load())
	require.Equal(t, 22.0, cache.setData[2])
	hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(1), hit)
	require.Equal(t, int64(1), miss)
	require.Equal(t, int64(1), batchSQL)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), errCount)
}
// TestWithWindowCostPrefetch_AllHitNoSQL verifies that when every account's
// window cost is already in the cache, no SQL query (batch or single) is issued.
func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		{
			ID:                 1,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeOAuth,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
	}
	// Both accounts are pre-populated in the cache.
	cache := &sessionLimitCacheHotpathStub{
		batchData: map[int64]float64{
			1: 11.0,
			2: 22.0,
		},
	}
	repo := &usageLogWindowBatchRepoStub{}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}
	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
	cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok1)
	require.True(t, ok2)
	require.Equal(t, 11.0, cost1)
	require.Equal(t, 22.0, cost2)
	// The repository must not be touched at all.
	require.Equal(t, int64(0), repo.batchCalls.Load())
	require.Equal(t, int64(0), repo.singleCalls.Load())
	hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(2), hit)
	require.Equal(t, int64(0), miss)
	require.Equal(t, int64(0), batchSQL)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), errCount)
}
// TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery verifies that a
// failing batch query degrades to a per-account single query, and that both
// the fallback and error counters are incremented.
func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
	}
	cache := &sessionLimitCacheHotpathStub{}
	repo := &usageLogWindowBatchRepoStub{
		batchErr: errors.New("batch failed"),
		singleResult: map[int64]*usagestats.AccountStats{
			2: {StandardCost: 33.0},
		},
	}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}
	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	cost, ok := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok)
	require.Equal(t, 33.0, cost)
	// One failed batch attempt followed by one single-account fallback query.
	require.Equal(t, int64(1), repo.batchCalls.Load())
	require.Equal(t, int64(1), repo.singleCalls.Load())
	_, _, _, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(1), fallback)
	require.Equal(t, int64(1), errCount)
}
// TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation verifies the
// short-lived models-list cache: reads within the TTL do not hit the
// repository, stale data is served until the cache is invalidated, and an
// explicit invalidation forces a reload.
func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	groupID := int64(9)
	repo := &modelsListAccountRepoStub{
		byGroup: map[int64][]Account{
			groupID: {
				{
					ID:       1,
					Platform: PlatformAnthropic,
					Credentials: map[string]any{
						"model_mapping": map[string]any{
							"claude-3-5-sonnet": "claude-3-5-sonnet",
							"claude-3-5-haiku":  "claude-3-5-haiku",
						},
					},
				},
				{
					ID:       2,
					Platform: PlatformGemini,
					Credentials: map[string]any{
						"model_mapping": map[string]any{
							"gemini-2.5-pro": "gemini-2.5-pro",
						},
					},
				},
			},
		},
	}
	svc := &GatewayService{
		accountRepo:        repo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}
	models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())
	// A second request within the TTL must be served from cache without
	// another repository round trip.
	models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, models1, models2)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())
	// Update the repository data; the stale list must keep being returned
	// until the cache entry is invalidated.
	repo.byGroup[groupID] = []Account{
		{
			ID:       3,
			Platform: PlatformAnthropic,
			Credentials: map[string]any{
				"model_mapping": map[string]any{
					"claude-3-7-sonnet": "claude-3-7-sonnet",
				},
			},
		},
	}
	models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())
	svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic)
	models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-7-sonnet"}, models4)
	require.Equal(t, int64(2), repo.listByGroupCalls.Load())
	hit, miss, store := GatewayModelsListCacheStats()
	require.Equal(t, int64(2), hit)
	require.Equal(t, int64(2), miss)
	require.Equal(t, int64(2), store)
}
// TestGetAvailableModels_ErrorAndGlobalListBranches verifies that a
// repository error yields a nil model list, and that the global branch
// (nil group, empty platform) returns models from all platforms.
func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) {
	resetGatewayHotpathStatsForTest()
	errRepo := &modelsListAccountRepoStub{
		err: errors.New("db error"),
	}
	svcErr := &GatewayService{
		accountRepo:        errRepo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}
	require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, ""))
	okRepo := &modelsListAccountRepoStub{
		all: []Account{
			{
				ID:       1,
				Platform: PlatformAnthropic,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"claude-3-5-sonnet": "claude-3-5-sonnet",
					},
				},
			},
			{
				ID:       2,
				Platform: PlatformGemini,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"gemini-2.5-pro": "gemini-2.5-pro",
					},
				},
			},
		},
	}
	svcOK := &GatewayService{
		accountRepo:        okRepo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}
	models := svcOK.GetAvailableModels(context.Background(), nil, "")
	require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models)
	require.Equal(t, int64(1), okRepo.listAllCalls.Load())
}
// TestGatewayHotpathHelpers_CacheTTLAndStickyContext covers the small
// hot-path helpers: TTL resolution from config and reading prefetched
// sticky-account / window-cost values out of a context.
func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
	t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) {
		// A nil config falls back to the package default TTL.
		require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil))
		cfg := &config.Config{
			Gateway: config.GatewayConfig{
				UserGroupRateCacheTTLSeconds: 45,
			},
		}
		require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg))
	})
	t.Run("resolve_models_list_cache_ttl", func(t *testing.T) {
		require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil))
		cfg := &config.Config{
			Gateway: config.GatewayConfig{
				ModelsListCacheTTLSeconds: 20,
			},
		}
		require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg))
	})
	t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil))
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil))
		ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
		ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
		require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil))
		groupID := int64(9)
		// NOTE(review): 456 is stored as an untyped int (not int64) and is
		// still expected back, so the helper presumably accepts both widths —
		// confirm against its implementation.
		ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
		ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID)
		require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID))
		// A non-numeric value must be ignored.
		ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
		ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID)
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID))
		// A prefetch recorded for a different group must not leak into this group.
		ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789))
		ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10))
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID))
	})
	t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
		// Idiom fix: assert the misses directly with require.False instead of
		// wrapping each lookup in an immediately-invoked closure compared
		// against false.
		_, ok := windowCostFromPrefetchContext(context.TODO(), 0)
		require.False(t, ok)
		_, ok = windowCostFromPrefetchContext(context.Background(), 1)
		require.False(t, ok)
		ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{
			9: 12.34,
		})
		cost, ok := windowCostFromPrefetchContext(ctx, 9)
		require.True(t, ok)
		require.Equal(t, 12.34, cost)
	})
}
// TestInvalidateAvailableModelsCache_ByDimensions verifies cache invalidation
// at three granularities: (group, platform), group-only, and platform-only.
func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) {
	svc := &GatewayService{
		modelsListCache: gocache.New(time.Minute, time.Minute),
	}
	group9 := int64(9)
	group10 := int64(10)
	svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
	svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
	svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)
	// A key outside the modelsListCacheKey scheme; presumably present to
	// exercise the invalidation scan's tolerance — it is never asserted on.
	svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute)
	t.Run("invalidate_group_and_platform", func(t *testing.T) {
		svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic)
		_, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		require.False(t, found)
		_, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.True(t, stillFound)
	})
	t.Run("invalidate_group_only", func(t *testing.T) {
		svc.InvalidateAvailableModelsCache(&group9, "")
		_, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		_, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.False(t, foundA)
		require.False(t, foundB)
		_, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
		require.True(t, foundOtherGroup)
	})
	t.Run("invalidate_platform_only", func(t *testing.T) {
		// Rebuild the entries, then invalidate by platform only.
		svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
		svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
		svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)
		svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic)
		_, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		_, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
		_, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.False(t, found9Anthropic)
		require.False(t, found10Anthropic)
		require.True(t, found9Gemini)
	})
}
// TestSelectAccountWithLoadAwareness_StickyReadReuse verifies how the sticky
// session account is resolved: with no prefetched value the sticky cache is
// read exactly once; a prefetched (account, group) pair that matches the
// request skips the cache read; a group mismatch falls back to one cache read.
func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
	now := time.Now().Add(-time.Minute)
	account := Account{
		ID:          88,
		Platform:    PlatformAnthropic,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 4,
		Priority:    1,
		LastUsedAt:  &now,
	}
	repo := stubOpenAIAccountRepo{accounts: []Account{account}}
	concurrency := NewConcurrencyService(stubConcurrencyCache{})
	cfg := &config.Config{
		RunMode: config.RunModeStandard,
		Gateway: config.GatewayConfig{
			Scheduling: config.GatewaySchedulingConfig{
				LoadBatchEnabled:         true,
				StickySessionMaxWaiting:  3,
				StickySessionWaitTimeout: time.Second,
				FallbackWaitTimeout:      time.Second,
				FallbackMaxWaiting:       10,
			},
		},
	}
	baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic)
	t.Run("without_prefetch_reads_cache_once", func(t *testing.T) {
		cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
		svc := &GatewayService{
			accountRepo:        repo,
			cache:              cache,
			cfg:                cfg,
			concurrencyService: concurrency,
			userGroupRateCache: gocache.New(time.Minute, time.Minute),
			modelsListCache:    gocache.New(time.Minute, time.Minute),
			modelsListCacheTTL: time.Minute,
		}
		result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Account)
		require.Equal(t, account.ID, result.Account.ID)
		// Exactly one sticky-cache lookup when nothing was prefetched.
		require.Equal(t, int64(1), cache.getCalls.Load())
	})
	t.Run("with_prefetch_skips_cache_read", func(t *testing.T) {
		cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
		svc := &GatewayService{
			accountRepo:        repo,
			cache:              cache,
			cfg:                cfg,
			concurrencyService: concurrency,
			userGroupRateCache: gocache.New(time.Minute, time.Minute),
			modelsListCache:    gocache.New(time.Minute, time.Minute),
			modelsListCacheTTL: time.Minute,
		}
		ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
		ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Account)
		require.Equal(t, account.ID, result.Account.ID)
		// The prefetched pair makes the sticky-cache read unnecessary.
		require.Equal(t, int64(0), cache.getCalls.Load())
	})
	t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) {
		cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
		svc := &GatewayService{
			accountRepo:        repo,
			cache:              cache,
			cfg:                cfg,
			concurrencyService: concurrency,
			userGroupRateCache: gocache.New(time.Minute, time.Minute),
			modelsListCache:    gocache.New(time.Minute, time.Minute),
			modelsListCacheTTL: time.Minute,
		}
		ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
		ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Account)
		require.Equal(t, account.ID, result.Account.ID)
		// The prefetch was recorded for group 77; this request must re-read
		// the sticky cache.
		require.Equal(t, int64(1), cache.getCalls.Load())
	})
}
......@@ -77,6 +77,11 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
// GetByCRSAccountID is a stub returning no account and no error.
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
	return nil, nil
}
// FindByExtraField is a stub returning no accounts and no error.
func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) {
	return nil, nil
}
// ListCRSAccountIDs is a stub returning an empty mapping and no error.
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
	return nil, nil
}
......
......@@ -5,9 +5,28 @@ import (
"encoding/json"
"fmt"
"math"
"unsafe"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
var (
	// Byte patterns used for fast-path substring checks. Kept as package-level
	// vars so each call avoids a fresh []byte("...") conversion/allocation.
	patternTypeThinking         = []byte(`"type":"thinking"`)
	patternTypeThinkingSpaced   = []byte(`"type": "thinking"`)
	patternTypeRedactedThinking = []byte(`"type":"redacted_thinking"`)
	patternTypeRedactedSpaced   = []byte(`"type": "redacted_thinking"`)
	patternThinkingField        = []byte(`"thinking":`)
	patternThinkingFieldSpaced  = []byte(`"thinking" :`)
	patternEmptyContent         = []byte(`"content":[]`)
	patternEmptyContentSpaced   = []byte(`"content": []`)
	patternEmptyContentSp1      = []byte(`"content" : []`)
	patternEmptyContentSp2      = []byte(`"content" :[]`)
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
......@@ -48,113 +67,127 @@ type ParsedRequest struct {
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
// 保持与旧实现一致:请求体必须是合法 JSON。
// 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。
if !gjson.ValidBytes(body) {
return nil, fmt.Errorf("invalid json")
}
// 性能:
// - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。
// - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。
jsonStr := *(*string)(unsafe.Pointer(&body))
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
// --- gjson 提取简单字段(避免完整 Unmarshal) ---
// model: 需要严格类型校验,非 string 返回错误
modelResult := gjson.Get(jsonStr, "model")
if modelResult.Exists() {
if modelResult.Type != gjson.String {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
parsed.Model = modelResult.String()
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
// stream: 需要严格类型校验,非 bool 返回错误
streamResult := gjson.Get(jsonStr, "stream")
if streamResult.Exists() {
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
parsed.Stream = streamResult.Bool()
}
// metadata.user_id: 直接路径提取,不需要严格类型校验
parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String()
// thinking.type: enabled/adaptive 都视为开启
thinkingType := gjson.Get(jsonStr, "thinking.type").String()
if thinkingType == "enabled" || thinkingType == "adaptive" {
parsed.ThinkingEnabled = true
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
// max_tokens: 仅接受整数值
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
f := maxTokensResult.Float()
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
f <= float64(math.MaxInt) && f >= float64(math.MinInt) {
parsed.MaxTokens = int(f)
}
}
// --- system/messages 提取 ---
// 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。
// 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() {
var parts []any
if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil {
return nil, err
}
parsed.System = parts
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() {
var msgs []any
if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil {
return nil, err
}
parsed.Messages = msgs
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
if sys := gjson.Get(jsonStr, "system"); sys.Exists() {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
// thinking: {type: "enabled" | "adaptive"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") {
parsed.ThinkingEnabled = true
switch sys.Type {
case gjson.Null:
parsed.System = nil
case gjson.String:
// 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。
parsed.System = sys.String()
default:
var system any
if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil {
return nil, err
}
parsed.System = system
}
}
}
// max_tokens
if rawMaxTokens, exists := req["max_tokens"]; exists {
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
parsed.MaxTokens = maxTokens
if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() {
var messages []any
if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil {
return nil, err
}
parsed.Messages = messages
}
}
return parsed, nil
}
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
func parseIntegralNumber(raw any) (int, bool) {
switch v := raw.(type) {
case float64:
if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) {
return 0, false
}
if v > float64(math.MaxInt) || v < float64(math.MinInt) {
return 0, false
}
return int(v), true
case int:
return v, true
case int8:
return int(v), true
case int16:
return int(v), true
case int32:
return int(v), true
case int64:
if v > int64(math.MaxInt) || v < int64(math.MinInt) {
return 0, false
// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。
// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。
// 当 Index 不可用时,退化为复制(理论上极少发生)。
func sliceRawFromBody(body []byte, r gjson.Result) []byte {
if r.Index > 0 {
end := r.Index + len(r.Raw)
if end <= len(body) {
return body[r.Index:end]
}
return int(v), true
case json.Number:
i64, err := v.Int64()
if err != nil {
return 0, false
}
if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) {
return 0, false
}
return int(i64), true
default:
return 0, false
}
// fallback: 不影响正确性,但会产生一次拷贝
return []byte(r.Raw)
}
// FilterThinkingBlocks removes thinking blocks from request body
......@@ -184,49 +217,63 @@ func FilterThinkingBlocks(body []byte) []byte {
// - Remove `redacted_thinking` blocks (cannot be converted to text).
// - Ensure no message ends up with empty content.
func FilterThinkingBlocksForRetry(body []byte) []byte {
hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
bytes.Contains(body, []byte(`"type": "thinking"`)) ||
bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
bytes.Contains(body, []byte(`"thinking":`)) ||
bytes.Contains(body, []byte(`"thinking" :`))
hasThinkingContent := bytes.Contains(body, patternTypeThinking) ||
bytes.Contains(body, patternTypeThinkingSpaced) ||
bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingField) ||
bytes.Contains(body, patternThinkingFieldSpaced)
// Also check for empty content arrays that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
bytes.Contains(body, []byte(`"content": []`)) ||
bytes.Contains(body, []byte(`"content" : []`)) ||
bytes.Contains(body, []byte(`"content" :[]`))
hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
bytes.Contains(body, patternEmptyContentSpaced) ||
bytes.Contains(body, patternEmptyContentSp1) ||
bytes.Contains(body, patternEmptyContentSp2)
// Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
// 尽量避免把整个 body Unmarshal 成 map(会产生大量 map/接口分配)。
// 这里先用 gjson 把 messages 子树摘出来,后续只对 messages 做 Unmarshal/Marshal。
jsonStr := *(*string)(unsafe.Pointer(&body))
msgsRes := gjson.Get(jsonStr, "messages")
if !msgsRes.Exists() || !msgsRes.IsArray() {
return body
}
modified := false
messages, ok := req["messages"].([]any)
if !ok {
// Fast path:只需要删除顶层 thinking,不需要改 messages。
// 注意:patternThinkingField 可能来自嵌套字段(如 tool_use.input.thinking),因此必须用 gjson 判断顶层字段是否存在。
containsThinkingBlocks := bytes.Contains(body, patternTypeThinking) ||
bytes.Contains(body, patternTypeThinkingSpaced) ||
bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingFieldSpaced)
if !hasEmptyContent && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
return out
}
return body
}
return body
}
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
var messages []any
if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil {
return body
}
newMessages := make([]any, 0, len(messages))
modified := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
deleteTopLevelThinking := gjson.Get(jsonStr, "thinking").Exists()
for i := 0; i < len(messages); i++ {
msgMap, ok := messages[i].(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
......@@ -234,17 +281,30 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
content, ok := msgMap["content"].([]any)
if !ok {
// String content or other format - keep as is
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
// 延迟分配:只有检测到需要修改的块,才构建新 slice。
var newContent []any
modifiedThisMsg := false
for _, block := range content {
ensureNewContent := func(prefixLen int) {
if newContent != nil {
return
}
newContent = make([]any, 0, len(content))
if prefixLen > 0 {
newContent = append(newContent, content[:prefixLen]...)
}
}
for bi := 0; bi < len(content); bi++ {
block := content[bi]
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
if newContent != nil {
newContent = append(newContent, block)
}
continue
}
......@@ -254,17 +314,15 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
switch blockType {
case "thinking":
modifiedThisMsg = true
ensureNewContent(bi)
thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
if thinkingText != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText})
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": thinkingText,
})
continue
case "redacted_thinking":
modifiedThisMsg = true
ensureNewContent(bi)
continue
}
......@@ -272,6 +330,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
ensureNewContent(bi)
switch v := rawThinking.(type) {
case string:
if v != "" {
......@@ -286,40 +345,64 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
}
}
newContent = append(newContent, block)
if newContent != nil {
newContent = append(newContent, block)
}
}
// Handle empty content: either from filtering or originally empty
if newContent == nil {
if len(content) == 0 {
modified = true
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}}
}
continue
}
if len(newContent) == 0 {
modified = true
placeholder := "(content removed)"
if role == "assistant" {
placeholder = "(assistant content removed)"
}
newContent = append(newContent, map[string]any{
"type": "text",
"text": placeholder,
})
msgMap["content"] = newContent
} else if modifiedThisMsg {
msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}}
continue
}
if modifiedThisMsg {
modified = true
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if modified {
req["messages"] = newMessages
} else {
if !modified && !deleteTopLevelThinking {
// Avoid rewriting JSON when no changes are needed.
return body
}
newBody, err := json.Marshal(req)
if err != nil {
return body
out := body
if deleteTopLevelThinking {
if b, err := sjson.DeleteBytes(out, "thinking"); err == nil {
out = b
} else {
return body
}
}
return newBody
if modified {
msgsBytes, err := json.Marshal(messages)
if err != nil {
return body
}
out, err = sjson.SetRawBytes(out, "messages", msgsBytes)
if err != nil {
return body
}
}
return out
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
......
//go:build unit
package service
import (
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
......@@ -434,3 +438,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
// Task 7.1 — 类型校验边界测试
// TestParseGatewayRequest_TypeValidation verifies that ParseGatewayRequest
// rejects requests whose "model" or "stream" fields carry a JSON type other
// than the expected string/bool, including explicit null values.
func TestParseGatewayRequest_TypeValidation(t *testing.T) {
	type typeCase struct {
		name      string
		body      string
		wantErr   bool
		errSubstr string // expected substring of the error message (empty = skip the check)
	}
	cases := []typeCase{
		{name: "model 为 int", body: `{"model":123}`, wantErr: true, errSubstr: "invalid model field type"},
		{name: "model 为 array", body: `{"model":[]}`, wantErr: true, errSubstr: "invalid model field type"},
		{name: "model 为 bool", body: `{"model":true}`, wantErr: true, errSubstr: "invalid model field type"},
		{name: "model 为 null — gjson Null 类型触发类型校验错误", body: `{"model":null}`, wantErr: true, errSubstr: "invalid model field type"},
		{name: "stream 为 string", body: `{"stream":"true"}`, wantErr: true, errSubstr: "invalid stream field type"},
		{name: "stream 为 int", body: `{"stream":1}`, wantErr: true, errSubstr: "invalid stream field type"},
		{name: "stream 为 null — gjson Null 类型触发类型校验错误", body: `{"stream":null}`, wantErr: true, errSubstr: "invalid stream field type"},
		{name: "model 为 object", body: `{"model":{}}`, wantErr: true, errSubstr: "invalid model field type"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := ParseGatewayRequest([]byte(tc.body), "")
			// Guard clause: the success path needs no further checks.
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errSubstr != "" {
				require.Contains(t, err.Error(), tc.errSubstr)
			}
		})
	}
}
// Task 7.2 — 可选字段缺失测试
// TestParseGatewayRequest_OptionalFieldsMissing verifies that absent or
// disabled optional fields fall back to their zero values after parsing.
func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) {
	type optionalCase struct {
		name            string
		body            string
		wantModel       string
		wantStream      bool
		wantMetadataUID string
		wantHasSystem   bool
		wantThinking    bool
		wantMaxTokens   int
		wantMessagesNil bool
		wantMessagesLen int
	}
	cases := []optionalCase{
		{name: "完全空 JSON — 所有字段零值", body: `{}`, wantMessagesNil: true},
		{name: "metadata 无 user_id", body: `{"model":"test"}`, wantModel: "test"},
		{name: "thinking 非 enabled(type=disabled)", body: `{"model":"test","thinking":{"type":"disabled"}}`, wantModel: "test"},
		{name: "thinking 字段缺失", body: `{"model":"test"}`, wantModel: "test"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ParseGatewayRequest([]byte(tc.body), "")
			require.NoError(t, err)
			require.Equal(t, tc.wantModel, got.Model)
			require.Equal(t, tc.wantStream, got.Stream)
			require.Equal(t, tc.wantMetadataUID, got.MetadataUserID)
			require.Equal(t, tc.wantHasSystem, got.HasSystem)
			require.Equal(t, tc.wantThinking, got.ThinkingEnabled)
			require.Equal(t, tc.wantMaxTokens, got.MaxTokens)
			if tc.wantMessagesNil {
				require.Nil(t, got.Messages)
			}
			if tc.wantMessagesLen > 0 {
				require.Len(t, got.Messages, tc.wantMessagesLen)
			}
		})
	}
}
// Task 7.3 — Gemini 协议分支测试
// 已有测试覆盖:
// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents
// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents
// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction)
// 因此跳过。
// Task 7.4 — max_tokens 边界测试
// TestParseGatewayRequest_MaxTokensBoundary exercises boundary values of the
// max_tokens field: integral, non-integral, negative, huge, and null inputs.
// Note: the original table carried a wantErr flag that no case ever set, so
// the dead field and its unreachable error branch have been removed.
func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) {
	tests := []struct {
		name          string
		body          string
		wantMaxTokens int
	}{
		{
			name:          "正常整数",
			body:          `{"max_tokens":1024}`,
			wantMaxTokens: 1024,
		},
		{
			name:          "浮点数(非整数)被忽略",
			body:          `{"max_tokens":10.5}`,
			wantMaxTokens: 0,
		},
		{
			name:          "负整数可以通过",
			body:          `{"max_tokens":-1}`,
			wantMaxTokens: -1,
		},
		{
			name:          "超大值不 panic",
			body:          `{"max_tokens":9999999999999999}`,
			wantMaxTokens: 10000000000000000, // float64 precision rounds 9999999999999999 up to 1e16
		},
		{
			name:          "null 值被忽略",
			body:          `{"max_tokens":null}`,
			wantMaxTokens: 0, // gjson Type=Null != Number, so the field is skipped
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parsed, err := ParseGatewayRequest([]byte(tt.body), "")
			require.NoError(t, err)
			require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
		})
	}
}
// ============ Task 7.5: Benchmark 测试 ============
// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。
// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。
// parseGatewayRequestOld is the legacy implementation that unmarshals the
// whole body into a map[string]any and extracts each field via type
// assertions. It is kept solely as the baseline for benchmarking against the
// newer ParseGatewayRequest.
func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, err
	}
	out := &ParsedRequest{Body: body}

	// model: must be a string when present.
	if rawModel, present := payload["model"]; present {
		model, isString := rawModel.(string)
		if !isString {
			return nil, fmt.Errorf("invalid model field type")
		}
		out.Model = model
	}
	// stream: must be a bool when present.
	if rawStream, present := payload["stream"]; present {
		stream, isBool := rawStream.(bool)
		if !isBool {
			return nil, fmt.Errorf("invalid stream field type")
		}
		out.Stream = stream
	}
	// metadata.user_id: optional, silently skipped on type mismatch.
	if meta, isMap := payload["metadata"].(map[string]any); isMap {
		if uid, isString := meta["user_id"].(string); isString {
			out.MetadataUserID = uid
		}
	}
	// thinking.type: only the exact value "enabled" turns the flag on.
	if thinking, isMap := payload["thinking"].(map[string]any); isMap {
		if kind, isString := thinking["type"].(string); isString && kind == "enabled" {
			out.ThinkingEnabled = true
		}
	}
	// max_tokens: accepted only when the number is integral.
	if rawMax, present := payload["max_tokens"]; present {
		if n, ok := parseIntegralNumber(rawMax); ok {
			out.MaxTokens = n
		}
	}

	// system / messages extraction differs per protocol.
	if protocol == domain.PlatformGemini {
		if sysInst, isMap := payload["systemInstruction"].(map[string]any); isMap {
			if parts, isSlice := sysInst["parts"].([]any); isSlice {
				out.System = parts
			}
		}
		if contents, isSlice := payload["contents"].([]any); isSlice {
			out.Messages = contents
		}
		return out, nil
	}
	if system, present := payload["system"]; present {
		out.HasSystem = true
		out.System = system
	}
	if msgs, isSlice := payload["messages"].([]any); isSlice {
		out.Messages = msgs
	}
	return out, nil
}
// buildSmallJSON 构建 ~500B 的小型测试 JSON
func buildSmallJSON() []byte {
return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`)
}
// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages)
func buildLargeJSON() []byte {
var b strings.Builder
b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`)
msgCount := 200
for i := 0; i < msgCount; i++ {
if i > 0 {
b.WriteByte(',')
}
if i%2 == 0 {
b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i))
} else {
b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i))
}
}
b.WriteString(`]}`)
return []byte(b.String())
}
// BenchmarkParseGatewayRequest_Old_Small measures the legacy map-based parser
// on the ~500B payload.
func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) {
	payload := buildSmallJSON()
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = parseGatewayRequestOld(payload, "")
	}
}
// BenchmarkParseGatewayRequest_New_Small measures the current parser on the
// ~500B payload.
func BenchmarkParseGatewayRequest_New_Small(b *testing.B) {
	payload := buildSmallJSON()
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = ParseGatewayRequest(payload, "")
	}
}
// BenchmarkParseGatewayRequest_Old_Large measures the legacy map-based parser
// on the ~50KB payload.
func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
	payload := buildLargeJSON()
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = parseGatewayRequestOld(payload, "")
	}
}
// BenchmarkParseGatewayRequest_New_Large measures the current parser on the
// ~50KB payload.
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
	payload := buildLargeJSON()
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = ParseGatewayRequest(payload, "")
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment