Commit 5e060b22 authored by erio's avatar erio
Browse files

Merge remote-tracking branch 'upstream/main' into feat/channel-insights

# Conflicts:
#	backend/cmd/server/wire_gen.go
parents 6f04c25e 0a80ec80
...@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t ...@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
}, },
} }
snapshot := svc.snapshotFromAPIKey(apiKey) snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip) require.NotNil(t, roundTrip)
......
...@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan := s.resolveSignupGrantPlan(ctx, "email") grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户 // 创建用户
user := &User{ user := &User{
Email: email, Email: email,
...@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
} }
...@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource := inferLegacySignupSource(email) signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
...@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource, SignupSource: signupSource,
} }
...@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource := inferLegacySignupSource(email) signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
...@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource, SignupSource: signupSource,
} }
......
...@@ -20,6 +20,9 @@ import ( ...@@ -20,6 +20,9 @@ import (
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -87,6 +90,8 @@ type BillingCacheService struct { ...@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader apiKeyRateLimitLoader apiKeyRateLimitLoader
userRPMCache UserRPMCache
userGroupRateRepo UserGroupRateRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker circuitBreaker *billingCircuitBreaker
...@@ -104,12 +109,22 @@ type BillingCacheService struct { ...@@ -104,12 +109,22 @@ type BillingCacheService struct {
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { func NewBillingCacheService(
cache BillingCache,
userRepo UserRepository,
subRepo UserSubscriptionRepository,
apiKeyRepo APIKeyRepository,
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
) *BillingCacheService {
svc := &BillingCacheService{ svc := &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo, apiKeyRateLimitLoader: apiKeyRepo,
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
...@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
} }
} }
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
if err := s.checkRPM(ctx, user, group); err != nil {
return err
}
return nil
}
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
//
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
//
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
if s == nil || s.userRPMCache == nil || user == nil {
return nil
}
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
if group != nil {
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
var override *int
if user.UserGroupRPMOverride != nil {
override = user.UserGroupRPMOverride
} else if s.userGroupRateRepo != nil {
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm override lookup failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
} else {
override = dbOverride
}
}
if override != nil {
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
if *override > 0 {
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if incErr != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
user.ID, group.ID, incErr,
)
// fail-open
} else if count > *override {
return ErrGroupRPMExceeded
}
}
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
} else if group.RPMLimit > 0 {
// 无 override,检查 group.rpm_limit。
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
// fail-open
} else if count > group.RPMLimit {
return ErrGroupRPMExceeded
}
}
}
// ── 第二层:用户级全局硬上限(始终生效) ──
if user.RPMLimit > 0 {
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (user) failed for user=%d: %v",
user.ID, err,
)
return nil // fail-open
}
if count > user.RPMLimit {
return ErrUserRPMExceeded
}
}
return nil return nil
} }
......
//go:build unit
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub records how many times each counter was invoked, and lets
// tests inject scripted return values and errors.
type userRPMCacheStub struct {
	userGroupCalls int32 // number of IncrementUserGroupRPM calls (read via atomic)
	userCalls      int32 // number of IncrementUserRPM calls (read via atomic)

	userGroupCounts []int // scripted counts returned in order; 1 once exhausted
	userGroupErr    error // if set, IncrementUserGroupRPM fails with this error
	userCounts      []int // scripted counts for IncrementUserRPM, same semantics
	userErr         error // if set, IncrementUserRPM fails with this error
}
// IncrementUserGroupRPM bumps the user-group call counter (even on injected
// error, so call counts stay observable), then returns either the injected
// error, the next scripted count, or 1 once the script is exhausted.
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	call := atomic.AddInt32(&s.userGroupCalls, 1)
	if s.userGroupErr != nil {
		return 0, s.userGroupErr
	}
	if pos := int(call) - 1; pos < len(s.userGroupCounts) {
		return s.userGroupCounts[pos], nil
	}
	return 1, nil
}
// IncrementUserRPM mirrors IncrementUserGroupRPM for the user-level counter:
// record the call, honor an injected error, then replay scripted counts.
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
	call := atomic.AddInt32(&s.userCalls, 1)
	if s.userErr != nil {
		return 0, s.userErr
	}
	if pos := int(call) - 1; pos < len(s.userCounts) {
		return s.userCounts[pos], nil
	}
	return 1, nil
}
// GetUserGroupRPM is a no-op read; checkRPM tests only exercise the increment paths.
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	return 0, nil
}
// GetUserRPM is a no-op read; checkRPM tests only exercise the increment paths.
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
	return 0, nil
}
// rpmOverrideRepoStub is purpose-built for checkRPM branch tests and implements
// only the methods those tests need; the embedded interface satisfies the rest.
type rpmOverrideRepoStub struct {
	UserGroupRateRepository       // embedded so unimplemented methods still type-check
	override                *int  // override returned by GetRPMOverrideByUserAndGroup (nil = no override)
	err                     error // if set, the lookup fails with this error
	calls                   int32 // number of lookups performed (read via atomic)
}
// GetRPMOverrideByUserAndGroup records the lookup, then yields either the
// injected error or the configured override pointer (which may be nil).
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
	atomic.AddInt32(&s.calls, 1)
	switch {
	case s.err != nil:
		return nil, s.err
	default:
		return s.override, nil
	}
}
// newBillingServiceForRPM builds a BillingCacheService wired with the given RPM
// cache and override repository, suitable for calling checkRPM directly.
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
	t.Helper()
	// Pass a nil BillingCache to take the "no cache" path and avoid
	// CheckBillingEligibility side effects — these tests call checkRPM directly.
	svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
	t.Cleanup(svc.Stop)
	return svc
}
// Verifies that a per-(user, group) rpm_override replaces group.rpm_limit:
// the override value (2) is the effective group-scoped cap even though the
// group itself allows 100.
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
	override := 2
	// user-group counts: 1, 2, 3; user counts: default 1 (far below RPMLimit=100, no interference)
	cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 100} // global cap set high so it does not interfere with the override
	group := &Group{ID: 10, RPMLimit: 100}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
	// Parallel design: first 2 calls pass the override → user cap is also checked;
	// 3rd call exceeds the override → returns immediately, user cap not checked.
	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
	require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
}
// Verifies that user.rpm_limit is a global hard cap: it rejects the request
// even when a much larger override would otherwise allow it.
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
	override := 100 // very generous override
	// user-group counts: default 1 (far below the override); user counts: 1, 2, 3
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2} // global hard cap = 2, should win over override = 100
	group := &Group{ID: 10, RPMLimit: 100}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
// Verifies that override=0 exempts the user from the group-scoped counter,
// while the user-level global cap continues to apply.
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
	zero := 0
	// user counts: 1..6 in order
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 5}
	group := &Group{ID: 10, RPMLimit: 100}

	// override=0 skips the group counter, but user.RPMLimit=5 still applies
	for i := 0; i < 5; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
	}
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
		"override=0 跳过分组但 user 全局上限仍应生效")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
	require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
}
// Verifies that override=0 combined with user.rpm_limit=0 yields fully
// unlimited traffic: neither counter is ever incremented.
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
	zero := 0
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0} // user-level cap disabled too
	group := &Group{ID: 10, RPMLimit: 100}

	for i := 0; i < 50; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
// Verifies that with no override configured (nil), the group.rpm_limit branch
// takes effect and rejects once the group counter passes the limit.
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
	// user-group counts: 5, 6; user counts: default 1 (no interference)
	cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 999} // global cap very high, so the group limit trips first
	group := &Group{ID: 10, RPMLimit: 5}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))                            // ug=5, user=1, neither exceeded
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
	// Parallel mode: 1st call group is within limit → user is also checked;
	// 2nd call group is exceeded → returns immediately, user not checked.
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
}
// Verifies that a failed override lookup does not reject the request: the
// check falls through to the group branch (and the group counter is used).
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
	cache := &userRPMCacheStub{userGroupCounts: []int{3}}
	repo := &rpmOverrideRepoStub{err: errors.New("db down")}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 10}

	// After the override lookup fails, the group branch should still be tried (no outright rejection)
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
}
// Verifies that when the group has no limit (rpm_limit=0, no override), only
// the user-level cap applies and the user-group key is never incremented.
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2}
	group := &Group{ID: 10, RPMLimit: 0} // group has no limit configured

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
// Verifies that with no limits configured anywhere (no override, group=0,
// user=0), checkRPM is a no-op and touches no counters.
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 0}

	for i := 0; i < 10; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
}
// Verifies the fail-open policy: a Redis increment error must not reject the
// request even though a group limit is configured.
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
	cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 5}

	// On Redis failure the check should fail open and not reject the request
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
}
// Verifies the group-less path: with a nil group, only the user-level cap is
// enforced and no rpm_override lookup is issued.
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2}

	// No group (pure user-level rate limiting); rpm_override must not be queried.
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
// Verifies the nil-user guard: checkRPM returns nil and performs no counter
// increments or repository lookups when user is nil.
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)

	require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
}
...@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond, delay: 80 * time.Millisecond,
balance: 12.34, balance: 12.34,
} }
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
const goroutines = 16 const goroutines = 16
......
...@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, ...@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
start := time.Now() start := time.Now()
...@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { ...@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop() svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{ enqueued := svc.enqueueCacheWrite(cacheWriteTask{
......
...@@ -217,6 +217,9 @@ func (s *BillingService) initFallbackPricing() { ...@@ -217,6 +217,9 @@ func (s *BillingService) initFallbackPricing() {
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
} }
// GPT-5.5 暂无独立定价,回退到 GPT-5.4
s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"]
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
InputPricePerToken: 7.5e-7, InputPricePerToken: 7.5e-7,
OutputPricePerToken: 4.5e-6, OutputPricePerToken: 4.5e-6,
...@@ -288,6 +291,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { ...@@ -288,6 +291,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower) normalized := normalizeCodexModel(modelLower)
switch normalized { switch normalized {
case "gpt-5.5":
return s.fallbackPrices["gpt-5.5"]
case "gpt-5.4-mini": case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"] return s.fallbackPrices["gpt-5.4-mini"]
case "gpt-5.4": case "gpt-5.4":
...@@ -637,7 +642,8 @@ func isOpenAIGPT54Model(model string) bool { ...@@ -637,7 +642,8 @@ func isOpenAIGPT54Model(model string) bool {
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false return false
} }
return normalizeCodexModel(trimmed) == "gpt-5.4" normalized := normalizeCodexModel(trimmed)
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
} }
// CalculateCostWithConfig 使用配置中的默认倍率计算费用 // CalculateCostWithConfig 使用配置中的默认倍率计算费用
......
...@@ -170,9 +170,10 @@ const ( ...@@ -170,9 +170,10 @@ const (
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置 // 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
......
...@@ -59,6 +59,10 @@ type Group struct { ...@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
RPMLimit int
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
package service
import "context"
// OpenAI403CounterCache tracks consecutive 403 failures for OpenAI accounts.
type OpenAI403CounterCache interface {
	// IncrementOpenAI403Count atomically increments the 403 counter and returns the current value.
	IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
	// ResetOpenAI403Count clears the counter after a successful request.
	ResetOpenAI403Count(ctx context.Context, accountID int64) error
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
) )
var codexModelMap = map[string]string{ var codexModelMap = map[string]string{
"gpt-5.5": "gpt-5.5",
"gpt-5.4": "gpt-5.4", "gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-mini": "gpt-5.4-mini",
"gpt-5.4-none": "gpt-5.4", "gpt-5.4-none": "gpt-5.4",
...@@ -207,6 +208,9 @@ func normalizeCodexModel(model string) string { ...@@ -207,6 +208,9 @@ func normalizeCodexModel(model string) string {
normalized := strings.ToLower(modelID) normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
return "gpt-5.5"
}
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini" return "gpt-5.4-mini"
} }
......
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// openAI403CounterResetStub records which account IDs had their 403 counter reset.
type openAI403CounterResetStub struct {
	resetCalls []int64 // account IDs passed to ResetOpenAI403Count, in order
}
// IncrementOpenAI403Count is a no-op; this stub only observes reset calls.
func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) {
	return 0, nil
}
// ResetOpenAI403Count records the account ID so the test can assert the reset happened.
func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
// Verifies that RecordUsage resets the account's 403 counter even when the
// result carries zero usage (the early zero-usage return must not skip the reset).
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
	counter := &openAI403CounterResetStub{}
	rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
	rateLimitSvc.SetOpenAI403CounterCache(counter)

	svc := &OpenAIGatewayService{
		rateLimitService: rateLimitSvc,
	}

	// Empty OpenAIForwardResult → all-zero usage → RecordUsage returns early.
	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result:  &OpenAIForwardResult{},
		Account: &Account{ID: 777, Platform: PlatformOpenAI},
	})
	require.NoError(t, err)
	require.Equal(t, []int64{777}, counter.resetCalls)
}
...@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. ...@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
require.NotNil(t, usageRepo.lastLog.BillingMode) require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
} }
// Verifies that an image result is billed per image even when the upstream
// usage also reports tokens: 2 images × ImagePrice1K (0.02) = 0.04 total, and
// all token-derived cost components stay zero.
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
	imagePrice := 0.02
	groupID := int64(12)
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)

	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_image_per_request",
			Model:     "gpt-image-2",
			// Token usage is present but must NOT drive billing for image results.
			Usage: OpenAIUsage{
				InputTokens:       1110,
				OutputTokens:      1756,
				ImageOutputTokens: 1756,
			},
			ImageCount: 2,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey: &APIKey{
			ID:      1008,
			GroupID: i64p(groupID),
			Group: &Group{
				ID:             groupID,
				RateMultiplier: 1.0,
				ImagePrice1K:   &imagePrice, // per-image price for the "1K" size tier
			},
		},
		User:    &User{ID: 2008},
		Account: &Account{ID: 3008},
	})
	require.NoError(t, err)
	require.NotNil(t, usageRepo.lastLog)
	require.NotNil(t, usageRepo.lastLog.BillingMode)
	require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
	require.Equal(t, 2, usageRepo.lastLog.ImageCount)
	// 2 images × 0.02 = 0.04; no token-based cost components.
	require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
	require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
}
...@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct { ...@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result result := input.Result
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
...@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( ...@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
serviceTier string, serviceTier string,
) (*CostBreakdown, error) { ) (*CostBreakdown, error) {
if result != nil && result.ImageCount > 0 { if result != nil && result.ImageCount > 0 {
if hasOpenAIImageUsageTokens(result) {
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
if err == nil {
return cost, nil
}
}
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
} }
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
...@@ -4646,32 +4643,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( ...@@ -4646,32 +4643,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
} }
func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
ctx context.Context,
apiKey *APIKey,
billingModel string,
multiplier float64,
tokens UsageTokens,
serviceTier string,
sizeTier string,
) (*CostBreakdown, error) {
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
return s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
SizeTier: sizeTier,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
}
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
func (s *OpenAIGatewayService) calculateOpenAIImageCost( func (s *OpenAIGatewayService) calculateOpenAIImageCost(
ctx context.Context, ctx context.Context,
billingModel string, billingModel string,
...@@ -4679,7 +4650,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( ...@@ -4679,7 +4650,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
result *OpenAIForwardResult, result *OpenAIForwardResult,
multiplier float64, multiplier float64,
) *CostBreakdown { ) *CostBreakdown {
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
gid := apiKey.Group.ID gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{ cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
...@@ -4720,17 +4692,6 @@ func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, ...@@ -4720,17 +4692,6 @@ func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context,
return nil return nil
} }
func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
if result == nil {
return false
}
return result.Usage.InputTokens > 0 ||
result.Usage.OutputTokens > 0 ||
result.Usage.CacheCreationInputTokens > 0 ||
result.Usage.CacheReadInputTokens > 0 ||
result.Usage.ImageOutputTokens > 0
}
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses. // Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
......
This diff is collapsed.
This diff is collapsed.
...@@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { ...@@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
} }
} }
// GPT-5.5 回退到 GPT-5.4 定价
if strings.HasPrefix(model, "gpt-5.5") {
logger.With(zap.String("component", "service.pricing")).
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)"))
return openAIGPT54FallbackPricing
}
if strings.HasPrefix(model, "gpt-5.4-mini") { if strings.HasPrefix(model, "gpt-5.4-mini") {
logger.With(zap.String("component", "service.pricing")). logger.With(zap.String("component", "service.pricing")).
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)")) Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)"))
......
...@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct { ...@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
updateCredentialsCalls int updateCredentialsCalls int
lastCredentials map[string]any lastCredentials map[string]any
lastErrorMsg string lastErrorMsg string
lastTempReason string
} }
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
...@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error ...@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++ r.tempCalls++
r.lastTempReason = reason
return nil return nil
} }
...@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct { ...@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
err error err error
} }
// openAI403CounterCacheStub scripts 403-counter increments and records resets.
type openAI403CounterCacheStub struct {
	counts     []int64 // scripted increment results, consumed front-to-back; 1 once exhausted
	resetCalls []int64 // account IDs passed to ResetOpenAI403Count, in order
	err        error   // if set, IncrementOpenAI403Count fails with this error
}
// IncrementOpenAI403Count returns the injected error if set; otherwise it pops
// and returns the next scripted count, falling back to 1 once the script runs out.
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	if s.err != nil {
		return 0, s.err
	}
	if len(s.counts) == 0 {
		return 1, nil
	}
	// Consume the scripted values front-to-back (the stub mutates its own slice).
	count := s.counts[0]
	s.counts = s.counts[1:]
	return count, nil
}
// ResetOpenAI403Count records the account ID so tests can assert the reset happened.
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account) r.accounts = append(r.accounts, account)
return r.err return r.err
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment