Commit dc5d42ad authored by james-6-23's avatar james-6-23
Browse files

feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
parent ef967d8f
......@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
NewUserRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,
......
......@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
"role": "user",
"balance": 12.5,
"concurrency": 5,
"rpm_limit": 0,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
......@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null,
"require_oauth_only": false,
"require_privacy_set": false,
"rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
......@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
......@@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
......@@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
......@@ -221,6 +221,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
......@@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
......
......@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"sort"
"strings"
"time"
......@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
......@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
......@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes string
Balance float64
Concurrency int
RPMLimit int
AllowedGroups []int64
}
......@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
......@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly bool
RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制)
RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
......@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly *bool
RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
......@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量
}
// UserRPMStatus describes a user's current per-minute RPM usage.
type UserRPMStatus struct {
UserRPMUsed int `json:"user_rpm_used"`
UserRPMLimit int `json:"user_rpm_limit"`
PerGroup []UserGroupRPMStatus `json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type UserGroupRPMStatus struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Used int `json:"used"`
Limit int `json:"limit"`
Source string `json:"source"` // "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
......@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
......@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
userRPMCache UserRPMCache
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
......@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
userRPMCache UserRPMCache,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
......@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
userRPMCache: userRPMCache,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
......@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
RPMLimit: input.RPMLimit,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
}
......@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
oldRPMLimit := user.RPMLimit
if input.Email != "" {
user.Email = input.Email
......@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency
}
if input.RPMLimit != nil {
user.RPMLimit = *input.RPMLimit
}
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
......@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
......@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil
}
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
if s.userRPMCache == nil {
return nil, ErrRPMStatusUnavailable
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
}
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
if err != nil {
return nil, err
}
groupIDSet := make(map[int64]struct{})
for _, key := range keys {
if key.GroupID != nil && *key.GroupID > 0 {
groupIDSet[*key.GroupID] = struct{}{}
}
}
groupIDs := make([]int64, 0, len(groupIDSet))
for groupID := range groupIDSet {
groupIDs = append(groupIDs, groupID)
}
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
var perGroup []UserGroupRPMStatus
for _, groupID := range groupIDs {
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
if getErr != nil {
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
}
entry := UserGroupRPMStatus{
GroupID: groupID,
Used: used,
}
if s.groupRepo != nil {
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
entry.GroupName = group.Name
entry.Limit = group.RPMLimit
entry.Source = "group"
} else if groupErr != nil {
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
}
}
if s.userGroupRateRepo != nil {
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
if overrideErr != nil {
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
} else if override != nil {
entry.Limit = *override
entry.Source = "override"
}
}
perGroup = append(perGroup, entry)
}
return &UserRPMStatus{
UserRPMUsed: userRPMUsed,
UserRPMLimit: user.RPMLimit,
PerGroup: perGroup,
}, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now
return map[string]any{
......@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
RPMLimit: input.RPMLimit,
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
......@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
if input.RPMLimit != nil {
group.RPMLimit = *input.RPMLimit
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
......@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
......@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
if s.userGroupRateRepo == nil {
return nil
}
for _, e := range entries {
if e.RPMOverride != nil && *e.RPMOverride < 0 {
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
}
}
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
......
......@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
"net/http"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
......@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
rpmSyncedGroupID int64
rpmSyncedEntries []GroupRPMOverrideInput
rpmSyncErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
......@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
......@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
s.rpmSyncedGroupID = groupID
s.rpmSyncedEntries = entries
return s.rpmSyncErr
}
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
......@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
},
},
}
......@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.NotNil(t, entries[0].RateMultiplier)
require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
require.NotNil(t, entries[1].RateMultiplier)
require.Equal(t, 0.8, *entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
......@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed")
})
}
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
override := 20
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
require.Equal(t, entries, repo.rpmSyncedEntries)
})
t.Run("rejects negative override as bad request", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
negative := -1
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
{UserID: 2, RPMOverride: &negative},
})
require.Error(t, err)
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
require.Zero(t, repo.rpmSyncedGroupID)
})
}
......@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAnthropic,
Status: StatusActive,
RPMLimit: 10,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
groupRepo: repo,
authCacheInvalidator: invalidator,
}
rpmLimit := 60
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
RPMLimit: &rpmLimit,
})
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, 60, repo.updated.RPMLimit)
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
......
......@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
......@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call")
}
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
panic("unexpected SyncGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type rpmStatusUserRepoStub struct {
UserRepository
user *User
}
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
return s.user, nil
}
type rpmStatusAPIKeyRepoStub struct {
APIKeyRepository
keys []APIKey
}
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
}
type rpmStatusGroupRepoStub struct {
GroupRepository
groups map[int64]*Group
}
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return s.groups[id], nil
}
type rpmStatusRateRepoStub struct {
UserGroupRateRepository
overrides map[int64]*int
}
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
return s.overrides[groupID], nil
}
type rpmStatusCacheStub struct {
UserRPMCache
userUsed int
groupUsed map[int64]int
}
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
return s.groupUsed[groupID], nil
}
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
return s.userUsed, nil
}
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
groupOneID := int64(1)
groupTwoID := int64(2)
override := 7
svc := &adminServiceImpl{
userRepo: &rpmStatusUserRepoStub{user: &User{
ID: 42,
RPMLimit: 20,
}},
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
{ID: 100, UserID: 42, GroupID: &groupTwoID},
{ID: 101, UserID: 42, GroupID: &groupOneID},
{ID: 102, UserID: 42, GroupID: &groupTwoID},
{ID: 103, UserID: 42},
}},
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
}},
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
groupTwoID: &override,
}},
userRPMCache: &rpmStatusCacheStub{
userUsed: 5,
groupUsed: map[int64]int{
groupOneID: 3,
groupTwoID: 4,
},
},
}
status, err := svc.GetUserRPMStatus(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, &UserRPMStatus{
UserRPMUsed: 5,
UserRPMLimit: 20,
PerGroup: []UserGroupRPMStatus{
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
},
}, status)
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
type rpmUserRepoStub struct {
*userRepoStub
lastUpdated *User
}
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
if user == nil {
return nil
}
clone := *user
s.lastUpdated = &clone
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newRPM := 60
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
RPMLimit: &newRPM,
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 60, updated.RPMLimit)
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
}
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newName := "new"
sameRPM := 10
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
Username: &newName,
RPMLimit: &sameRPM,
})
require.NoError(t, err)
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
}
......@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
RPMLimit int `json:"rpm_limit"`
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
......@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
RPMLimit int `json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
......@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct {
l1Size int
......@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
......@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
......@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
RPMLimit: apiKey.User.RPMLimit,
},
}
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
if err == nil && override != nil {
snapshot.User.UserGroupRPMOverride = override
}
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
......@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
RPMLimit: apiKey.Group.RPMLimit,
}
}
return snapshot
......@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
RPMLimit: snapshot.User.RPMLimit,
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
},
}
if snapshot.Group != nil {
......@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
RPMLimit: snapshot.Group.RPMLimit,
}
}
s.compileAPIKeyIPRules(apiKey)
......
......@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
......
......@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
user := &User{
Email: email,
......@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
......@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
......@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
......@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
......@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
......
......@@ -20,6 +20,9 @@ import (
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
......@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
userRPMCache UserRPMCache
userGroupRateRepo UserGroupRateRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
......@@ -104,12 +109,22 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
func NewBillingCacheService(
cache BillingCache,
userRepo UserRepository,
subRepo UserSubscriptionRepository,
apiKeyRepo APIKeyRepository,
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
......@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
if err := s.checkRPM(ctx, user, group); err != nil {
return err
}
return nil
}
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
//
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
//
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
if s == nil || s.userRPMCache == nil || user == nil {
return nil
}
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
if group != nil {
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
var override *int
if user.UserGroupRPMOverride != nil {
override = user.UserGroupRPMOverride
} else if s.userGroupRateRepo != nil {
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm override lookup failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
} else {
override = dbOverride
}
}
if override != nil {
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
if *override > 0 {
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if incErr != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
user.ID, group.ID, incErr,
)
// fail-open
} else if count > *override {
return ErrGroupRPMExceeded
}
}
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
} else if group.RPMLimit > 0 {
// 无 override,检查 group.rpm_limit。
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
// fail-open
} else if count > group.RPMLimit {
return ErrGroupRPMExceeded
}
}
}
// ── 第二层:用户级全局硬上限(始终生效) ──
if user.RPMLimit > 0 {
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (user) failed for user=%d: %v",
user.ID, err,
)
return nil // fail-open
}
if count > user.RPMLimit {
return ErrUserRPMExceeded
}
}
return nil
}
......
//go:build unit
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
type userRPMCacheStub struct {
userGroupCalls int32
userCalls int32
userGroupCounts []int // 依次返回的计数值
userGroupErr error
userCounts []int
userErr error
}
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
if s.userGroupErr != nil {
return 0, s.userGroupErr
}
if idx < len(s.userGroupCounts) {
return s.userGroupCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
if s.userErr != nil {
return 0, s.userErr
}
if idx < len(s.userCounts) {
return s.userCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
return 0, nil
}
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
return 0, nil
}
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
type rpmOverrideRepoStub struct {
UserGroupRateRepository
override *int
err error
calls int32
}
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
atomic.AddInt32(&s.calls, 1)
if s.err != nil {
return nil, s.err
}
return s.override, nil
}
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
t.Helper()
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
// 我们只直接测 checkRPM。
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
t.Cleanup(svc.Stop)
return svc
}
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
override := 2
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: &override}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
group := &Group{ID: 10, RPMLimit: 100}
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
}
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
override := 100 // override 很高
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: &override}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
group := &Group{ID: 10, RPMLimit: 100}
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
zero := 0
// user 计数: 依次返回 1..6
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
repo := &rpmOverrideRepoStub{override: &zero}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 5}
group := &Group{ID: 10, RPMLimit: 100}
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
for i := 0; i < 5; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
}
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
"override=0 跳过分组但 user 全局上限仍应生效")
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
}
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
zero := 0
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{override: &zero}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0} // user 也不限
group := &Group{ID: 10, RPMLimit: 100}
for i := 0; i < 50; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group))
}
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
group := &Group{ID: 10, RPMLimit: 5}
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
}
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
cache := &userRPMCacheStub{userGroupCounts: []int{3}}
repo := &rpmOverrideRepoStub{err: errors.New("db down")}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 10}
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
}
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2}
group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 0}
for i := 0; i < 10; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group))
}
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 5}
// Redis 故障时应 fail-open,不拒绝请求
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
}
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2}
// 无 group(纯用户级限流场景),不应查询 rpm_override。
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{}
svc := newBillingServiceForRPM(t, cache, repo)
require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
}
......@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
......
......@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
......@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
......
......@@ -173,6 +173,7 @@ const (
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
......
......@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
RPMLimit int
CreatedAt time.Time
UpdatedAt time.Time
......
......@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
......@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance
}
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
if err != nil || value == "" {
return 0
}
if v, err := strconv.Atoi(value); err == nil && v >= 0 {
return v
}
return 0
}
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
......@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
......@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
result.DefaultUserRPMLimit = rpm
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment