Commit dc5d42ad authored by james-6-23's avatar james-6-23
Browse files

feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
parent ef967d8f
...@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet( ...@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
ProvideConcurrencyCache, ProvideConcurrencyCache,
ProvideSessionLimitCache, ProvideSessionLimitCache,
NewRPMCache, NewRPMCache,
NewUserRPMCache,
NewUserMsgQueueCache, NewUserMsgQueueCache,
NewDashboardCache, NewDashboardCache,
NewEmailCache, NewEmailCache,
......
...@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
"role": "user", "role": "user",
"balance": 12.5, "balance": 12.5,
"concurrency": 5, "concurrency": 5,
"rpm_limit": 0,
"status": "active", "status": "active",
"allowed_groups": null, "allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
...@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null, "fallback_group_id_on_invalid_request": null,
"require_oauth_only": false, "require_oauth_only": false,
"require_privacy_set": false, "require_privacy_set": false,
"rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false, "force_email_on_third_party_signup": false,
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25, "default_balance": 1.25,
"default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...@@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [], "custom_endpoints": [],
"default_concurrency": 0, "default_concurrency": 0,
"default_balance": 0, "default_balance": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...@@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
...@@ -221,6 +221,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -221,6 +221,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
// User attribute values // User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
...@@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
} }
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"sort"
"strings" "strings"
"time" "time"
...@@ -32,6 +33,7 @@ type AdminService interface { ...@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups). // Also returns totalRecharged (sum of all positive balance top-ups).
...@@ -50,6 +52,8 @@ type AdminService interface { ...@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin) // API Key management (admin)
...@@ -114,6 +118,7 @@ type CreateUserInput struct { ...@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
RPMLimit int
AllowedGroups []int64 AllowedGroups []int64
} }
...@@ -124,6 +129,7 @@ type UpdateUserInput struct { ...@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes *string Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0" Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
...@@ -199,6 +205,8 @@ type CreateGroupInput struct { ...@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly bool RequireOAuthOnly bool
RequirePrivacySet bool RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制)
RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -234,6 +242,8 @@ type UpdateGroupInput struct { ...@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly *bool RequireOAuthOnly *bool
RequirePrivacySet *bool RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct { ...@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量 MigratedKeys int64 // 迁移的 Key 数量
} }
// UserRPMStatus describes a user's current per-minute RPM usage.
type UserRPMStatus struct {
UserRPMUsed int `json:"user_rpm_used"`
UserRPMLimit int `json:"user_rpm_limit"`
PerGroup []UserGroupRPMStatus `json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type UserGroupRPMStatus struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Used int `json:"used"`
Limit int `json:"limit"`
Source string `json:"source"` // "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates. // BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct { type BulkUpdateAccountsResult struct {
Success int `json:"success"` Success int `json:"success"`
...@@ -463,6 +489,8 @@ const ( ...@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
) )
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo UserRepository userRepo UserRepository
...@@ -472,6 +500,7 @@ type adminServiceImpl struct { ...@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository userGroupRateRepo UserGroupRateRepository
userRPMCache UserRPMCache
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
...@@ -496,6 +525,7 @@ func NewAdminService( ...@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository, userGroupRateRepo UserGroupRateRepository,
userRPMCache UserRPMCache,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
...@@ -514,6 +544,7 @@ func NewAdminService( ...@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo, userGroupRateRepo: userGroupRateRepo,
userRPMCache: userRPMCache,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
...@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu ...@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
RPMLimit: input.RPMLimit,
Status: StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups, AllowedGroups: input.AllowedGroups,
} }
...@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency oldConcurrency := user.Concurrency
oldStatus := user.Status oldStatus := user.Status
oldRole := user.Role oldRole := user.Role
oldRPMLimit := user.RPMLimit
if input.Email != "" { if input.Email != "" {
user.Email = input.Email user.Email = input.Email
...@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency user.Concurrency = *input.Concurrency
} }
if input.RPMLimit != nil {
user.RPMLimit = *input.RPMLimit
}
if input.AllowedGroups != nil { if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
...@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
} }
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
} }
} }
...@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag ...@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil return keys, result.Total, nil
} }
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
if s.userRPMCache == nil {
return nil, ErrRPMStatusUnavailable
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
}
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
if err != nil {
return nil, err
}
groupIDSet := make(map[int64]struct{})
for _, key := range keys {
if key.GroupID != nil && *key.GroupID > 0 {
groupIDSet[*key.GroupID] = struct{}{}
}
}
groupIDs := make([]int64, 0, len(groupIDSet))
for groupID := range groupIDSet {
groupIDs = append(groupIDs, groupID)
}
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
var perGroup []UserGroupRPMStatus
for _, groupID := range groupIDs {
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
if getErr != nil {
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
}
entry := UserGroupRPMStatus{
GroupID: groupID,
Used: used,
}
if s.groupRepo != nil {
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
entry.GroupName = group.Name
entry.Limit = group.RPMLimit
entry.Source = "group"
} else if groupErr != nil {
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
}
}
if s.userGroupRateRepo != nil {
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
if overrideErr != nil {
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
} else if override != nil {
entry.Limit = *override
entry.Source = "override"
}
}
perGroup = append(perGroup, entry)
}
return &UserRPMStatus{
UserRPMUsed: userRPMUsed,
UserRPMLimit: user.RPMLimit,
PerGroup: perGroup,
}, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now // Return mock data for now
return map[string]any{ return map[string]any{
...@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel, DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
RPMLimit: input.RPMLimit,
} }
sanitizeGroupMessagesDispatchFields(group) sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
...@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MessagesDispatchModelConfig != nil { if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
} }
if input.RPMLimit != nil {
group.RPMLimit = *input.RPMLimit
}
sanitizeGroupMessagesDispatchFields(group) sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 { if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs // 去重源分组 IDs
...@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
} }
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil return group, nil
} }
...@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro ...@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
} }
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
if s.userGroupRateRepo == nil {
return nil
}
for _, e := range entries {
if e.RPMOverride != nil && *e.RPMOverride < 0 {
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
}
}
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
return err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
}
return nil
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates) return s.groupRepo.UpdateSortOrders(ctx, updates)
} }
......
...@@ -5,8 +5,10 @@ package service ...@@ -5,8 +5,10 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"net/http"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct { ...@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64 syncedGroupID int64
syncedEntries []GroupRateMultiplierInput syncedEntries []GroupRateMultiplierInput
syncGroupErr error syncGroupErr error
rpmSyncedGroupID int64
rpmSyncedEntries []GroupRPMOverrideInput
rpmSyncErr error
} }
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
...@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, ...@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call") panic("unexpected GetByUserAndGroup call")
} }
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil { if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr return nil, s.getByGroupIDErr
...@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C ...@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr return s.syncGroupErr
} }
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
s.rpmSyncedGroupID = groupID
s.rpmSyncedEntries = entries
return s.rpmSyncErr
}
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr return s.deleteByGroupErr
...@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { ...@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{ repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{ getByGroupIDData: map[int64][]UserGroupRateEntry{
10: { 10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
}, },
}, },
} }
...@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { ...@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2) require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID) require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName) require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier) require.NotNil(t, entries[0].RateMultiplier)
require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID) require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier) require.NotNil(t, entries[1].RateMultiplier)
require.Equal(t, 0.8, *entries[1].RateMultiplier)
}) })
t.Run("returns nil when repo is nil", func(t *testing.T) { t.Run("returns nil when repo is nil", func(t *testing.T) {
...@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { ...@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed") require.Contains(t, err.Error(), "sync failed")
}) })
} }
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
override := 20
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
require.Equal(t, entries, repo.rpmSyncedEntries)
})
t.Run("rejects negative override as bad request", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
negative := -1
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
{UserID: 2, RPMOverride: &negative},
})
require.Error(t, err)
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
require.Zero(t, repo.rpmSyncedGroupID)
})
}
...@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAnthropic,
Status: StatusActive,
RPMLimit: 10,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
groupRepo: repo,
authCacheInvalidator: invalidator,
}
rpmLimit := 60
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
RPMLimit: &rpmLimit,
})
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, 60, repo.updated.RPMLimit)
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
......
...@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, ...@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call") panic("unexpected GetByUserAndGroup call")
} }
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call") panic("unexpected SyncUserGroupRates call")
} }
...@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C ...@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call") panic("unexpected SyncGroupRateMultipliers call")
} }
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
panic("unexpected SyncGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call") panic("unexpected DeleteByGroupID call")
} }
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type rpmStatusUserRepoStub struct {
UserRepository
user *User
}
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
return s.user, nil
}
type rpmStatusAPIKeyRepoStub struct {
APIKeyRepository
keys []APIKey
}
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
}
type rpmStatusGroupRepoStub struct {
GroupRepository
groups map[int64]*Group
}
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return s.groups[id], nil
}
type rpmStatusRateRepoStub struct {
UserGroupRateRepository
overrides map[int64]*int
}
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
return s.overrides[groupID], nil
}
type rpmStatusCacheStub struct {
UserRPMCache
userUsed int
groupUsed map[int64]int
}
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
return 0, nil
}
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
return s.groupUsed[groupID], nil
}
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
return s.userUsed, nil
}
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
groupOneID := int64(1)
groupTwoID := int64(2)
override := 7
svc := &adminServiceImpl{
userRepo: &rpmStatusUserRepoStub{user: &User{
ID: 42,
RPMLimit: 20,
}},
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
{ID: 100, UserID: 42, GroupID: &groupTwoID},
{ID: 101, UserID: 42, GroupID: &groupOneID},
{ID: 102, UserID: 42, GroupID: &groupTwoID},
{ID: 103, UserID: 42},
}},
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
}},
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
groupTwoID: &override,
}},
userRPMCache: &rpmStatusCacheStub{
userUsed: 5,
groupUsed: map[int64]int{
groupOneID: 3,
groupTwoID: 4,
},
},
}
status, err := svc.GetUserRPMStatus(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, &UserRPMStatus{
UserRPMUsed: 5,
UserRPMLimit: 20,
PerGroup: []UserGroupRPMStatus{
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
},
}, status)
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
type rpmUserRepoStub struct {
*userRepoStub
lastUpdated *User
}
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
if user == nil {
return nil
}
clone := *user
s.lastUpdated = &clone
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newRPM := 60
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
RPMLimit: &newRPM,
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 60, updated.RPMLimit)
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
}
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newName := "new"
sameRPM := 10
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
Username: &newName,
RPMLimit: &sameRPM,
})
require.NoError(t, err)
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
}
...@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct { ...@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"` TotalRecharged float64 `json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
RPMLimit int `json:"rpm_limit"`
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
} }
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照
...@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
RPMLimit int `json:"rpm_limit"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
...@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st ...@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
apiKey.Key = key apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey) snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil { if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
} }
...@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn ...@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
} }
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil { if apiKey == nil || apiKey.User == nil {
return nil return nil
} }
...@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged, TotalRecharged: apiKey.User.TotalRecharged,
RPMLimit: apiKey.User.RPMLimit,
}, },
} }
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
if err == nil && override != nil {
snapshot.User.UserGroupRPMOverride = override
}
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
}
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
...@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel, DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
RPMLimit: apiKey.Group.RPMLimit,
} }
} }
return snapshot return snapshot
...@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged, TotalRecharged: snapshot.User.TotalRecharged,
RPMLimit: snapshot.User.RPMLimit,
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
}, },
} }
if snapshot.Group != nil { if snapshot.Group != nil {
...@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel, DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
RPMLimit: snapshot.Group.RPMLimit,
} }
} }
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)
......
...@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t ...@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
}, },
} }
snapshot := svc.snapshotFromAPIKey(apiKey) snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip) require.NotNil(t, roundTrip)
......
...@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan := s.resolveSignupGrantPlan(ctx, "email") grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户 // 创建用户
user := &User{ user := &User{
Email: email, Email: email,
...@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
} }
...@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource := inferLegacySignupSource(email) signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
...@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource, SignupSource: signupSource,
} }
...@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource := inferLegacySignupSource(email) signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
...@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role: RoleUser, Role: RoleUser,
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource, SignupSource: signupSource,
} }
......
...@@ -20,6 +20,9 @@ import ( ...@@ -20,6 +20,9 @@ import (
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -87,6 +90,8 @@ type BillingCacheService struct { ...@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader apiKeyRateLimitLoader apiKeyRateLimitLoader
userRPMCache UserRPMCache
userGroupRateRepo UserGroupRateRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker circuitBreaker *billingCircuitBreaker
...@@ -104,12 +109,22 @@ type BillingCacheService struct { ...@@ -104,12 +109,22 @@ type BillingCacheService struct {
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { func NewBillingCacheService(
cache BillingCache,
userRepo UserRepository,
subRepo UserSubscriptionRepository,
apiKeyRepo APIKeyRepository,
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
) *BillingCacheService {
svc := &BillingCacheService{ svc := &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo, apiKeyRateLimitLoader: apiKeyRepo,
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
...@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
} }
} }
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
if err := s.checkRPM(ctx, user, group); err != nil {
return err
}
return nil
}
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
//
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
//
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
if s == nil || s.userRPMCache == nil || user == nil {
return nil
}
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
if group != nil {
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
var override *int
if user.UserGroupRPMOverride != nil {
override = user.UserGroupRPMOverride
} else if s.userGroupRateRepo != nil {
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm override lookup failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
} else {
override = dbOverride
}
}
if override != nil {
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
if *override > 0 {
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if incErr != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
user.ID, group.ID, incErr,
)
// fail-open
} else if count > *override {
return ErrGroupRPMExceeded
}
}
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
} else if group.RPMLimit > 0 {
// 无 override,检查 group.rpm_limit。
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
// fail-open
} else if count > group.RPMLimit {
return ErrGroupRPMExceeded
}
}
}
// ── 第二层:用户级全局硬上限(始终生效) ──
if user.RPMLimit > 0 {
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (user) failed for user=%d: %v",
user.ID, err,
)
return nil // fail-open
}
if count > user.RPMLimit {
return ErrUserRPMExceeded
}
}
return nil return nil
} }
......
//go:build unit
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
type userRPMCacheStub struct {
userGroupCalls int32
userCalls int32
userGroupCounts []int // 依次返回的计数值
userGroupErr error
userCounts []int
userErr error
}
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
if s.userGroupErr != nil {
return 0, s.userGroupErr
}
if idx < len(s.userGroupCounts) {
return s.userGroupCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
if s.userErr != nil {
return 0, s.userErr
}
if idx < len(s.userCounts) {
return s.userCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
return 0, nil
}
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
return 0, nil
}
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
type rpmOverrideRepoStub struct {
UserGroupRateRepository
override *int
err error
calls int32
}
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
atomic.AddInt32(&s.calls, 1)
if s.err != nil {
return nil, s.err
}
return s.override, nil
}
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
t.Helper()
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
// 我们只直接测 checkRPM。
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
t.Cleanup(svc.Stop)
return svc
}
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
override := 2
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: &override}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
group := &Group{ID: 10, RPMLimit: 100}
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
}
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
override := 100 // override 很高
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: &override}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
group := &Group{ID: 10, RPMLimit: 100}
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
zero := 0
// user 计数: 依次返回 1..6
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
repo := &rpmOverrideRepoStub{override: &zero}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 5}
group := &Group{ID: 10, RPMLimit: 100}
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
for i := 0; i < 5; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
}
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
"override=0 跳过分组但 user 全局上限仍应生效")
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
}
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
zero := 0
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{override: &zero}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0} // user 也不限
group := &Group{ID: 10, RPMLimit: 100}
for i := 0; i < 50; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group))
}
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
group := &Group{ID: 10, RPMLimit: 5}
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
}
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
cache := &userRPMCacheStub{userGroupCounts: []int{3}}
repo := &rpmOverrideRepoStub{err: errors.New("db down")}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 10}
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
}
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2}
group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 0}
for i := 0; i < 10; i++ {
require.NoError(t, svc.checkRPM(context.Background(), user, group))
}
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
repo := &rpmOverrideRepoStub{override: nil}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 0}
group := &Group{ID: 10, RPMLimit: 5}
// Redis 故障时应 fail-open,不拒绝请求
require.NoError(t, svc.checkRPM(context.Background(), user, group))
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
}
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
repo := &rpmOverrideRepoStub{}
svc := newBillingServiceForRPM(t, cache, repo)
user := &User{ID: 1, RPMLimit: 2}
// 无 group(纯用户级限流场景),不应查询 rpm_override。
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
cache := &userRPMCacheStub{}
repo := &rpmOverrideRepoStub{}
svc := newBillingServiceForRPM(t, cache, repo)
require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
}
...@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond, delay: 80 * time.Millisecond,
balance: 12.34, balance: 12.34,
} }
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
const goroutines = 16 const goroutines = 16
......
...@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, ...@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop) t.Cleanup(svc.Stop)
start := time.Now() start := time.Now()
...@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { ...@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{} cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop() svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{ enqueued := svc.enqueueCacheWrite(cacheWriteTask{
......
...@@ -170,9 +170,10 @@ const ( ...@@ -170,9 +170,10 @@ const (
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置 // 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
......
...@@ -59,6 +59,10 @@ type Group struct { ...@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
RPMLimit int
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
...@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting ...@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal default subscriptions: %w", err) return nil, fmt.Errorf("marshal default subscriptions: %w", err)
...@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { ...@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance return s.cfg.Default.UserBalance
} }
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
if err != nil || value == "" {
return 0
}
if v, err := strconv.Atoi(value); err == nil && v >= 0 {
return v
}
return 0
}
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。 // GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
...@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]", SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0", SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeyAuthSourceDefaultEmailConcurrency: "5", SettingKeyAuthSourceDefaultEmailConcurrency: "5",
...@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultConcurrency = s.cfg.Default.UserConcurrency result.DefaultConcurrency = s.cfg.Default.UserConcurrency
} }
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
result.DefaultUserRPMLimit = rpm
}
// 解析浮点数类型 // 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil { if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance result.DefaultBalance = balance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment