Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6b0cf466
Unverified
Commit
6b0cf466
authored
Apr 23, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 23, 2026
Browse files
Merge pull request #1815 from james-6-23/feat_rpm
feat(rpm): RPM 限流模块优化
parents
ef967d8f
dc5d42ad
Changes
79
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/wire.go
View file @
6b0cf466
...
...
@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewRPMCache
,
NewUserRPMCache
,
NewUserMsgQueueCache
,
NewDashboardCache
,
NewEmailCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
6b0cf466
...
...
@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
"role": "user",
"balance": 12.5,
"concurrency": 5,
"rpm_limit": 0,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
...
...
@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null,
"require_oauth_only": false,
"require_privacy_set": false,
"rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...
...
@@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
...
...
@@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
redeemService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
...
...
backend/internal/server/routes/admin.go
View file @
6b0cf466
...
...
@@ -221,6 +221,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users
.
GET
(
"/:id/usage"
,
h
.
Admin
.
User
.
GetUserUsage
)
users
.
GET
(
"/:id/balance-history"
,
h
.
Admin
.
User
.
GetBalanceHistory
)
users
.
POST
(
"/:id/replace-group"
,
h
.
Admin
.
User
.
ReplaceGroup
)
users
.
GET
(
"/:id/rpm-status"
,
h
.
Admin
.
User
.
GetUserRPMStatus
)
// User attribute values
users
.
GET
(
"/:id/attributes"
,
h
.
Admin
.
UserAttribute
.
GetUserAttributes
)
...
...
@@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups
.
GET
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
GetGroupRateMultipliers
)
groups
.
PUT
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
BatchSetGroupRateMultipliers
)
groups
.
DELETE
(
"/:id/rate-multipliers"
,
h
.
Admin
.
Group
.
ClearGroupRateMultipliers
)
groups
.
PUT
(
"/:id/rpm-overrides"
,
h
.
Admin
.
Group
.
BatchSetGroupRPMOverrides
)
groups
.
DELETE
(
"/:id/rpm-overrides"
,
h
.
Admin
.
Group
.
ClearGroupRPMOverrides
)
groups
.
GET
(
"/:id/api-keys"
,
h
.
Admin
.
Group
.
GetGroupAPIKeys
)
}
}
...
...
backend/internal/service/admin_service.go
View file @
6b0cf466
...
...
@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"sort"
"strings"
"time"
...
...
@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
,
notes
string
)
(
*
User
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
,
sortBy
,
sortOrder
string
)
([]
APIKey
,
int64
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
GetUserRPMStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserRPMStatus
,
error
)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
...
...
@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
)
([]
UserGroupRateEntry
,
error
)
ClearGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
)
error
BatchSetGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRateMultiplierInput
)
error
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
BatchSetGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
// API Key management (admin)
...
...
@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes
string
Balance
float64
Concurrency
int
RPMLimit
int
AllowedGroups
[]
int64
}
...
...
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes
*
string
Balance
*
float64
// 使用指针区分"未提供"和"设置为0"
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
RPMLimit
*
int
// 使用指针区分"未提供"和"设置为0"
Status
string
AllowedGroups
*
[]
int64
// 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
...
...
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly
bool
RequirePrivacySet
bool
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制)
RPMLimit
int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly
*
bool
RequirePrivacySet
*
bool
MessagesDispatchModelConfig
*
OpenAIMessagesDispatchModelConfig
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
RPMLimit
*
int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys
int64
// 迁移的 Key 数量
}
// UserRPMStatus describes a user's current per-minute RPM usage.
type
UserRPMStatus
struct
{
UserRPMUsed
int
`json:"user_rpm_used"`
UserRPMLimit
int
`json:"user_rpm_limit"`
PerGroup
[]
UserGroupRPMStatus
`json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type
UserGroupRPMStatus
struct
{
GroupID
int64
`json:"group_id"`
GroupName
string
`json:"group_name"`
Used
int
`json:"used"`
Limit
int
`json:"limit"`
Source
string
`json:"source"`
// "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type
BulkUpdateAccountsResult
struct
{
Success
int
`json:"success"`
...
...
@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent
=
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
var
ErrRPMStatusUnavailable
=
infraerrors
.
New
(
http
.
StatusNotImplemented
,
"RPM_STATUS_UNAVAILABLE"
,
"RPM cache not available"
)
// adminServiceImpl implements AdminService
type
adminServiceImpl
struct
{
userRepo
UserRepository
...
...
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo
APIKeyRepository
redeemCodeRepo
RedeemCodeRepository
userGroupRateRepo
UserGroupRateRepository
userRPMCache
UserRPMCache
billingCacheService
*
BillingCacheService
proxyProber
ProxyExitInfoProber
proxyLatencyCache
ProxyLatencyCache
...
...
@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo
APIKeyRepository
,
redeemCodeRepo
RedeemCodeRepository
,
userGroupRateRepo
UserGroupRateRepository
,
userRPMCache
UserRPMCache
,
billingCacheService
*
BillingCacheService
,
proxyProber
ProxyExitInfoProber
,
proxyLatencyCache
ProxyLatencyCache
,
...
...
@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo
:
apiKeyRepo
,
redeemCodeRepo
:
redeemCodeRepo
,
userGroupRateRepo
:
userGroupRateRepo
,
userRPMCache
:
userRPMCache
,
billingCacheService
:
billingCacheService
,
proxyProber
:
proxyProber
,
proxyLatencyCache
:
proxyLatencyCache
,
...
...
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role
:
RoleUser
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
Concurrency
:
input
.
Concurrency
,
RPMLimit
:
input
.
RPMLimit
,
Status
:
StatusActive
,
AllowedGroups
:
input
.
AllowedGroups
,
}
...
...
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency
:=
user
.
Concurrency
oldStatus
:=
user
.
Status
oldRole
:=
user
.
Role
oldRPMLimit
:=
user
.
RPMLimit
if
input
.
Email
!=
""
{
user
.
Email
=
input
.
Email
...
...
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user
.
Concurrency
=
*
input
.
Concurrency
}
if
input
.
RPMLimit
!=
nil
{
user
.
RPMLimit
=
*
input
.
RPMLimit
}
if
input
.
AllowedGroups
!=
nil
{
user
.
AllowedGroups
=
*
input
.
AllowedGroups
}
...
...
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if
s
.
authCacheInvalidator
!=
nil
{
if
user
.
Concurrency
!=
oldConcurrency
||
user
.
Status
!=
oldStatus
||
user
.
Role
!=
oldRole
{
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
if
user
.
Concurrency
!=
oldConcurrency
||
user
.
Status
!=
oldStatus
||
user
.
Role
!=
oldRole
||
user
.
RPMLimit
!=
oldRPMLimit
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
user
.
ID
)
}
}
...
...
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return
keys
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetUserRPMStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserRPMStatus
,
error
)
{
if
s
.
userRPMCache
==
nil
{
return
nil
,
ErrRPMStatusUnavailable
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
userRPMUsed
,
err
:=
s
.
userRPMCache
.
GetUserRPM
(
ctx
,
userID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get user rpm: user_id=%d err=%v"
,
userID
,
err
)
}
keys
,
_
,
err
:=
s
.
GetUserAPIKeys
(
ctx
,
userID
,
1
,
1000
,
""
,
""
)
if
err
!=
nil
{
return
nil
,
err
}
groupIDSet
:=
make
(
map
[
int64
]
struct
{})
for
_
,
key
:=
range
keys
{
if
key
.
GroupID
!=
nil
&&
*
key
.
GroupID
>
0
{
groupIDSet
[
*
key
.
GroupID
]
=
struct
{}{}
}
}
groupIDs
:=
make
([]
int64
,
0
,
len
(
groupIDSet
))
for
groupID
:=
range
groupIDSet
{
groupIDs
=
append
(
groupIDs
,
groupID
)
}
sort
.
Slice
(
groupIDs
,
func
(
i
,
j
int
)
bool
{
return
groupIDs
[
i
]
<
groupIDs
[
j
]
})
var
perGroup
[]
UserGroupRPMStatus
for
_
,
groupID
:=
range
groupIDs
{
used
,
getErr
:=
s
.
userRPMCache
.
GetUserGroupRPM
(
ctx
,
userID
,
groupID
)
if
getErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get user group rpm: user_id=%d group_id=%d err=%v"
,
userID
,
groupID
,
getErr
)
}
entry
:=
UserGroupRPMStatus
{
GroupID
:
groupID
,
Used
:
used
,
}
if
s
.
groupRepo
!=
nil
{
if
group
,
groupErr
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
);
groupErr
==
nil
&&
group
!=
nil
{
entry
.
GroupName
=
group
.
Name
entry
.
Limit
=
group
.
RPMLimit
entry
.
Source
=
"group"
}
else
if
groupErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get group rpm status metadata: group_id=%d err=%v"
,
groupID
,
groupErr
)
}
}
if
s
.
userGroupRateRepo
!=
nil
{
override
,
overrideErr
:=
s
.
userGroupRateRepo
.
GetRPMOverrideByUserAndGroup
(
ctx
,
userID
,
groupID
)
if
overrideErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to get rpm override: user_id=%d group_id=%d err=%v"
,
userID
,
groupID
,
overrideErr
)
}
else
if
override
!=
nil
{
entry
.
Limit
=
*
override
entry
.
Source
=
"override"
}
}
perGroup
=
append
(
perGroup
,
entry
)
}
return
&
UserRPMStatus
{
UserRPMUsed
:
userRPMUsed
,
UserRPMLimit
:
user
.
RPMLimit
,
PerGroup
:
perGroup
,
},
nil
}
func
(
s
*
adminServiceImpl
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
{
// Return mock data for now
return
map
[
string
]
any
{
...
...
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet
:
input
.
RequirePrivacySet
,
DefaultMappedModel
:
input
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
normalizeOpenAIMessagesDispatchModelConfig
(
input
.
MessagesDispatchModelConfig
),
RPMLimit
:
input
.
RPMLimit
,
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
...
...
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
MessagesDispatchModelConfig
!=
nil
{
group
.
MessagesDispatchModelConfig
=
normalizeOpenAIMessagesDispatchModelConfig
(
*
input
.
MessagesDispatchModelConfig
)
}
if
input
.
RPMLimit
!=
nil
{
group
.
RPMLimit
=
*
input
.
RPMLimit
}
sanitizeGroupMessagesDispatchFields
(
group
)
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
...
...
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
return
group
,
nil
}
...
...
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return
s
.
userGroupRateRepo
.
SyncGroupRateMultipliers
(
ctx
,
groupID
,
entries
)
}
func
(
s
*
adminServiceImpl
)
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
{
if
s
.
userGroupRateRepo
==
nil
{
return
nil
}
if
err
:=
s
.
userGroupRateRepo
.
ClearGroupRPMOverrides
(
ctx
,
groupID
);
err
!=
nil
{
return
err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
groupID
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
BatchSetGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
{
if
s
.
userGroupRateRepo
==
nil
{
return
nil
}
for
_
,
e
:=
range
entries
{
if
e
.
RPMOverride
!=
nil
&&
*
e
.
RPMOverride
<
0
{
return
infraerrors
.
BadRequest
(
"INVALID_RPM_OVERRIDE"
,
fmt
.
Sprintf
(
"rpm_override must be >= 0 (user_id=%d)"
,
e
.
UserID
))
}
}
if
err
:=
s
.
userGroupRateRepo
.
SyncGroupRPMOverrides
(
ctx
,
groupID
,
entries
);
err
!=
nil
{
return
err
}
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
groupID
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
s
.
groupRepo
.
UpdateSortOrders
(
ctx
,
updates
)
}
...
...
backend/internal/service/admin_service_group_rate_test.go
View file @
6b0cf466
...
...
@@ -5,8 +5,10 @@ package service
import
(
"context"
"errors"
"net/http"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
...
...
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID
int64
syncedEntries
[]
GroupRateMultiplierInput
syncGroupErr
error
rpmSyncedGroupID
int64
rpmSyncedEntries
[]
GroupRPMOverrideInput
rpmSyncErr
error
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetByUserID
(
_
context
.
Context
,
_
int64
)
(
map
[
int64
]
float64
,
error
)
{
...
...
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic
(
"unexpected GetByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
_
int64
)
(
*
int
,
error
)
{
panic
(
"unexpected GetRPMOverrideByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
GetByGroupID
(
_
context
.
Context
,
groupID
int64
)
([]
UserGroupRateEntry
,
error
)
{
if
s
.
getByGroupIDErr
!=
nil
{
return
nil
,
s
.
getByGroupIDErr
...
...
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return
s
.
syncGroupErr
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
SyncGroupRPMOverrides
(
_
context
.
Context
,
groupID
int64
,
entries
[]
GroupRPMOverrideInput
)
error
{
s
.
rpmSyncedGroupID
=
groupID
s
.
rpmSyncedEntries
=
entries
return
s
.
rpmSyncErr
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
ClearGroupRPMOverrides
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected ClearGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForGroupRate
)
DeleteByGroupID
(
_
context
.
Context
,
groupID
int64
)
error
{
s
.
deletedGroupIDs
=
append
(
s
.
deletedGroupIDs
,
groupID
)
return
s
.
deleteByGroupErr
...
...
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo
:=
&
userGroupRateRepoStubForGroupRate
{
getByGroupIDData
:
map
[
int64
][]
UserGroupRateEntry
{
10
:
{
{
UserID
:
1
,
UserName
:
"alice"
,
UserEmail
:
"alice@test.com"
,
RateMultiplier
:
1.5
},
{
UserID
:
2
,
UserName
:
"bob"
,
UserEmail
:
"bob@test.com"
,
RateMultiplier
:
0.8
},
{
UserID
:
1
,
UserName
:
"alice"
,
UserEmail
:
"alice@test.com"
,
RateMultiplier
:
ptrFloat
(
1.5
)
},
{
UserID
:
2
,
UserName
:
"bob"
,
UserEmail
:
"bob@test.com"
,
RateMultiplier
:
ptrFloat
(
0.8
)
},
},
},
}
...
...
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require
.
Len
(
t
,
entries
,
2
)
require
.
Equal
(
t
,
int64
(
1
),
entries
[
0
]
.
UserID
)
require
.
Equal
(
t
,
"alice"
,
entries
[
0
]
.
UserName
)
require
.
Equal
(
t
,
1.5
,
entries
[
0
]
.
RateMultiplier
)
require
.
NotNil
(
t
,
entries
[
0
]
.
RateMultiplier
)
require
.
Equal
(
t
,
1.5
,
*
entries
[
0
]
.
RateMultiplier
)
require
.
Equal
(
t
,
int64
(
2
),
entries
[
1
]
.
UserID
)
require
.
Equal
(
t
,
0.8
,
entries
[
1
]
.
RateMultiplier
)
require
.
NotNil
(
t
,
entries
[
1
]
.
RateMultiplier
)
require
.
Equal
(
t
,
0.8
,
*
entries
[
1
]
.
RateMultiplier
)
})
t
.
Run
(
"returns nil when repo is nil"
,
func
(
t
*
testing
.
T
)
{
...
...
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"sync failed"
)
})
}
func
TestAdminService_BatchSetGroupRPMOverrides
(
t
*
testing
.
T
)
{
t
.
Run
(
"syncs entries to repo"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
userGroupRateRepoStubForGroupRate
{}
svc
:=
&
adminServiceImpl
{
userGroupRateRepo
:
repo
}
override
:=
20
entries
:=
[]
GroupRPMOverrideInput
{{
UserID
:
2
,
RPMOverride
:
&
override
}}
err
:=
svc
.
BatchSetGroupRPMOverrides
(
context
.
Background
(),
10
,
entries
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
repo
.
rpmSyncedGroupID
)
require
.
Equal
(
t
,
entries
,
repo
.
rpmSyncedEntries
)
})
t
.
Run
(
"rejects negative override as bad request"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
userGroupRateRepoStubForGroupRate
{}
svc
:=
&
adminServiceImpl
{
userGroupRateRepo
:
repo
}
negative
:=
-
1
err
:=
svc
.
BatchSetGroupRPMOverrides
(
context
.
Background
(),
10
,
[]
GroupRPMOverrideInput
{
{
UserID
:
2
,
RPMOverride
:
&
negative
},
})
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
infraerrors
.
Code
(
err
))
require
.
Zero
(
t
,
repo
.
rpmSyncedGroupID
)
})
}
backend/internal/service/admin_service_group_test.go
View file @
6b0cf466
...
...
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
func
TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange
(
t
*
testing
.
T
)
{
existingGroup
:=
&
Group
{
ID
:
1
,
Name
:
"existing-group"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
RPMLimit
:
10
,
}
repo
:=
&
groupRepoStubForAdmin
{
getByID
:
existingGroup
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
,
authCacheInvalidator
:
invalidator
,
}
rpmLimit
:=
60
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
1
,
&
UpdateGroupInput
{
RPMLimit
:
&
rpmLimit
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
Equal
(
t
,
60
,
repo
.
updated
.
RPMLimit
)
require
.
Equal
(
t
,
[]
int64
{
1
},
invalidator
.
groupIDs
,
"分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存"
)
}
func
TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
...
...
backend/internal/service/admin_service_list_users_test.go
View file @
6b0cf466
...
...
@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic
(
"unexpected GetByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
_
int64
)
(
*
int
,
error
)
{
panic
(
"unexpected GetRPMOverrideByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
SyncUserGroupRates
(
_
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
panic
(
"unexpected SyncUserGroupRates call"
)
}
...
...
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic
(
"unexpected SyncGroupRateMultipliers call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
SyncGroupRPMOverrides
(
_
context
.
Context
,
_
int64
,
_
[]
GroupRPMOverrideInput
)
error
{
panic
(
"unexpected SyncGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
ClearGroupRPMOverrides
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected ClearGroupRPMOverrides call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
DeleteByGroupID
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected DeleteByGroupID call"
)
}
...
...
backend/internal/service/admin_service_rpm_status_test.go
0 → 100644
View file @
6b0cf466
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
rpmStatusUserRepoStub
struct
{
UserRepository
user
*
User
}
func
(
s
*
rpmStatusUserRepoStub
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
User
,
error
)
{
return
s
.
user
,
nil
}
type
rpmStatusAPIKeyRepoStub
struct
{
APIKeyRepository
keys
[]
APIKey
}
func
(
s
*
rpmStatusAPIKeyRepoStub
)
ListByUserID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
,
_
APIKeyListFilters
)
([]
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
keys
,
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
keys
))},
nil
}
type
rpmStatusGroupRepoStub
struct
{
GroupRepository
groups
map
[
int64
]
*
Group
}
func
(
s
*
rpmStatusGroupRepoStub
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
groups
[
id
],
nil
}
type
rpmStatusRateRepoStub
struct
{
UserGroupRateRepository
overrides
map
[
int64
]
*
int
}
func
(
s
*
rpmStatusRateRepoStub
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
groupID
int64
)
(
*
int
,
error
)
{
return
s
.
overrides
[
groupID
],
nil
}
type
rpmStatusCacheStub
struct
{
UserRPMCache
userUsed
int
groupUsed
map
[
int64
]
int
}
func
(
s
*
rpmStatusCacheStub
)
IncrementUserGroupRPM
(
context
.
Context
,
int64
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
rpmStatusCacheStub
)
IncrementUserRPM
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
rpmStatusCacheStub
)
GetUserGroupRPM
(
_
context
.
Context
,
_
,
groupID
int64
)
(
int
,
error
)
{
return
s
.
groupUsed
[
groupID
],
nil
}
func
(
s
*
rpmStatusCacheStub
)
GetUserRPM
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
s
.
userUsed
,
nil
}
func
TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits
(
t
*
testing
.
T
)
{
groupOneID
:=
int64
(
1
)
groupTwoID
:=
int64
(
2
)
override
:=
7
svc
:=
&
adminServiceImpl
{
userRepo
:
&
rpmStatusUserRepoStub
{
user
:
&
User
{
ID
:
42
,
RPMLimit
:
20
,
}},
apiKeyRepo
:
&
rpmStatusAPIKeyRepoStub
{
keys
:
[]
APIKey
{
{
ID
:
100
,
UserID
:
42
,
GroupID
:
&
groupTwoID
},
{
ID
:
101
,
UserID
:
42
,
GroupID
:
&
groupOneID
},
{
ID
:
102
,
UserID
:
42
,
GroupID
:
&
groupTwoID
},
{
ID
:
103
,
UserID
:
42
},
}},
groupRepo
:
&
rpmStatusGroupRepoStub
{
groups
:
map
[
int64
]
*
Group
{
groupOneID
:
{
ID
:
groupOneID
,
Name
:
"group-one"
,
RPMLimit
:
10
},
groupTwoID
:
{
ID
:
groupTwoID
,
Name
:
"group-two"
,
RPMLimit
:
60
},
}},
userGroupRateRepo
:
&
rpmStatusRateRepoStub
{
overrides
:
map
[
int64
]
*
int
{
groupTwoID
:
&
override
,
}},
userRPMCache
:
&
rpmStatusCacheStub
{
userUsed
:
5
,
groupUsed
:
map
[
int64
]
int
{
groupOneID
:
3
,
groupTwoID
:
4
,
},
},
}
status
,
err
:=
svc
.
GetUserRPMStatus
(
context
.
Background
(),
42
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
&
UserRPMStatus
{
UserRPMUsed
:
5
,
UserRPMLimit
:
20
,
PerGroup
:
[]
UserGroupRPMStatus
{
{
GroupID
:
groupOneID
,
GroupName
:
"group-one"
,
Used
:
3
,
Limit
:
10
,
Source
:
"group"
},
{
GroupID
:
groupTwoID
,
GroupName
:
"group-two"
,
Used
:
4
,
Limit
:
7
,
Source
:
"override"
},
},
},
status
)
}
backend/internal/service/admin_service_update_user_rpm_test.go
0 → 100644
View file @
6b0cf466
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
type
rpmUserRepoStub
struct
{
*
userRepoStub
lastUpdated
*
User
}
func
(
s
*
rpmUserRepoStub
)
Update
(
_
context
.
Context
,
user
*
User
)
error
{
if
user
==
nil
{
return
nil
}
clone
:=
*
user
s
.
lastUpdated
=
&
clone
if
s
.
userRepoStub
!=
nil
{
s
.
userRepoStub
.
user
=
&
clone
}
return
nil
}
func
TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange
(
t
*
testing
.
T
)
{
base
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
42
,
Email
:
"u@example.com"
,
RPMLimit
:
10
}}
repo
:=
&
rpmUserRepoStub
{
userRepoStub
:
base
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
&
redeemRepoStub
{},
authCacheInvalidator
:
invalidator
,
}
newRPM
:=
60
updated
,
err
:=
svc
.
UpdateUser
(
context
.
Background
(),
42
,
&
UpdateUserInput
{
RPMLimit
:
&
newRPM
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
updated
)
require
.
Equal
(
t
,
60
,
updated
.
RPMLimit
)
require
.
Equal
(
t
,
[]
int64
{
42
},
invalidator
.
userIDs
,
"仅修改 RPMLimit 也应失效 API Key 认证缓存"
)
}
func
TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged
(
t
*
testing
.
T
)
{
base
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
42
,
Email
:
"u@example.com"
,
RPMLimit
:
10
,
Username
:
"old"
}}
repo
:=
&
rpmUserRepoStub
{
userRepoStub
:
base
}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
&
redeemRepoStub
{},
authCacheInvalidator
:
invalidator
,
}
newName
:=
"new"
sameRPM
:=
10
_
,
err
:=
svc
.
UpdateUser
(
context
.
Background
(),
42
,
&
UpdateUserInput
{
Username
:
&
newName
,
RPMLimit
:
&
sameRPM
,
})
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
invalidator
.
userIDs
,
"只改 username 不应触发认证缓存失效"
)
}
backend/internal/service/api_key_auth_cache.go
View file @
6b0cf466
...
...
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails,omitempty"`
TotalRecharged
float64
`json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
RPMLimit
int
`json:"rpm_limit"`
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
UserGroupRPMOverride
*
int
`json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
...
...
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch
bool
`json:"allow_messages_dispatch"`
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config,omitempty"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
RPMLimit
int
`json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
6b0cf466
...
...
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const
apiKeyAuthSnapshotVersion
=
5
// v
5
: added
TotalRecharged for percentage thre
sho
ld
const
apiKeyAuthSnapshotVersion
=
7
// v
7
: added
UserGroupRPMOverride on user snap
sho
t
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
...
...
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
apiKey
.
Key
=
key
snapshot
:=
s
.
snapshotFromAPIKey
(
apiKey
)
snapshot
:=
s
.
snapshotFromAPIKey
(
ctx
,
apiKey
)
if
snapshot
==
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
ErrAPIKeyNotFound
)
}
...
...
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return
s
.
snapshotToAPIKey
(
key
,
entry
.
Snapshot
),
true
,
nil
}
func
(
s
*
APIKeyService
)
snapshotFromAPIKey
(
apiKey
*
APIKey
)
*
APIKeyAuthSnapshot
{
func
(
s
*
APIKeyService
)
snapshotFromAPIKey
(
ctx
context
.
Context
,
apiKey
*
APIKey
)
*
APIKeyAuthSnapshot
{
if
apiKey
==
nil
||
apiKey
.
User
==
nil
{
return
nil
}
...
...
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold
:
apiKey
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
apiKey
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
apiKey
.
User
.
TotalRecharged
,
RPMLimit
:
apiKey
.
User
.
RPMLimit
,
},
}
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
if
apiKey
.
GroupID
!=
nil
&&
*
apiKey
.
GroupID
>
0
&&
s
.
userGroupRateRepo
!=
nil
{
override
,
err
:=
s
.
userGroupRateRepo
.
GetRPMOverrideByUserAndGroup
(
ctx
,
apiKey
.
UserID
,
*
apiKey
.
GroupID
)
if
err
==
nil
&&
override
!=
nil
{
snapshot
.
User
.
UserGroupRPMOverride
=
override
}
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
}
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
...
...
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch
:
apiKey
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
apiKey
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
apiKey
.
Group
.
MessagesDispatchModelConfig
,
RPMLimit
:
apiKey
.
Group
.
RPMLimit
,
}
}
return
snapshot
...
...
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold
:
snapshot
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
snapshot
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
snapshot
.
User
.
TotalRecharged
,
RPMLimit
:
snapshot
.
User
.
RPMLimit
,
UserGroupRPMOverride
:
snapshot
.
User
.
UserGroupRPMOverride
,
},
}
if
snapshot
.
Group
!=
nil
{
...
...
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch
:
snapshot
.
Group
.
AllowMessagesDispatch
,
DefaultMappedModel
:
snapshot
.
Group
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
snapshot
.
Group
.
MessagesDispatchModelConfig
,
RPMLimit
:
snapshot
.
Group
.
RPMLimit
,
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
...
...
backend/internal/service/api_key_service_cache_test.go
View file @
6b0cf466
...
...
@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
},
}
snapshot
:=
svc
.
snapshotFromAPIKey
(
apiKey
)
snapshot
:=
svc
.
snapshotFromAPIKey
(
context
.
Background
(),
apiKey
)
roundTrip
:=
svc
.
snapshotToAPIKey
(
apiKey
.
Key
,
snapshot
)
require
.
NotNil
(
t
,
roundTrip
)
...
...
backend/internal/service/auth_service.go
View file @
6b0cf466
...
...
@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
"email"
)
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
var
defaultRPMLimit
int
if
s
.
settingService
!=
nil
{
defaultRPMLimit
=
s
.
settingService
.
GetDefaultUserRPMLimit
(
ctx
)
}
// 创建用户
user
:=
&
User
{
Email
:
email
,
...
...
@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role
:
RoleUser
,
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
RPMLimit
:
defaultRPMLimit
,
Status
:
StatusActive
,
}
...
...
@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource
:=
inferLegacySignupSource
(
email
)
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
signupSource
)
var
defaultRPMLimit
int
if
s
.
settingService
!=
nil
{
defaultRPMLimit
=
s
.
settingService
.
GetDefaultUserRPMLimit
(
ctx
)
}
newUser
:=
&
User
{
Email
:
email
,
...
...
@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role
:
RoleUser
,
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
RPMLimit
:
defaultRPMLimit
,
Status
:
StatusActive
,
SignupSource
:
signupSource
,
}
...
...
@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource
:=
inferLegacySignupSource
(
email
)
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
signupSource
)
var
defaultRPMLimit
int
if
s
.
settingService
!=
nil
{
defaultRPMLimit
=
s
.
settingService
.
GetDefaultUserRPMLimit
(
ctx
)
}
newUser
:=
&
User
{
Email
:
email
,
...
...
@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role
:
RoleUser
,
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
RPMLimit
:
defaultRPMLimit
,
Status
:
StatusActive
,
SignupSource
:
signupSource
,
}
...
...
backend/internal/service/billing_cache_service.go
View file @
6b0cf466
...
...
@@ -20,6 +20,9 @@ import (
var
(
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrBillingServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"BILLING_SERVICE_ERROR"
,
"Billing service temporarily unavailable. Please retry later."
)
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded
=
infraerrors
.
TooManyRequests
(
"GROUP_RPM_EXCEEDED"
,
"group requests-per-minute limit exceeded"
)
ErrUserRPMExceeded
=
infraerrors
.
TooManyRequests
(
"USER_RPM_EXCEEDED"
,
"user requests-per-minute limit exceeded"
)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
...
...
@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo
UserRepository
subRepo
UserSubscriptionRepository
apiKeyRateLimitLoader
apiKeyRateLimitLoader
userRPMCache
UserRPMCache
userGroupRateRepo
UserGroupRateRepository
cfg
*
config
.
Config
circuitBreaker
*
billingCircuitBreaker
...
...
@@ -104,12 +109,22 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
func
NewBillingCacheService
(
cache
BillingCache
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
apiKeyRepo
APIKeyRepository
,
cfg
*
config
.
Config
)
*
BillingCacheService
{
func
NewBillingCacheService
(
cache
BillingCache
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
apiKeyRepo
APIKeyRepository
,
userRPMCache
UserRPMCache
,
userGroupRateRepo
UserGroupRateRepository
,
cfg
*
config
.
Config
,
)
*
BillingCacheService
{
svc
:=
&
BillingCacheService
{
cache
:
cache
,
userRepo
:
userRepo
,
subRepo
:
subRepo
,
apiKeyRateLimitLoader
:
apiKeyRepo
,
userRPMCache
:
userRPMCache
,
userGroupRateRepo
:
userGroupRateRepo
,
cfg
:
cfg
,
}
svc
.
circuitBreaker
=
newBillingCircuitBreaker
(
cfg
.
Billing
.
CircuitBreaker
)
...
...
@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
if
err
:=
s
.
checkRPM
(
ctx
,
user
,
group
);
err
!=
nil
{
return
err
}
return
nil
}
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
//
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
//
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
func
(
s
*
BillingCacheService
)
checkRPM
(
ctx
context
.
Context
,
user
*
User
,
group
*
Group
)
error
{
if
s
==
nil
||
s
.
userRPMCache
==
nil
||
user
==
nil
{
return
nil
}
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
if
group
!=
nil
{
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
var
override
*
int
if
user
.
UserGroupRPMOverride
!=
nil
{
override
=
user
.
UserGroupRPMOverride
}
else
if
s
.
userGroupRateRepo
!=
nil
{
dbOverride
,
err
:=
s
.
userGroupRateRepo
.
GetRPMOverrideByUserAndGroup
(
ctx
,
user
.
ID
,
group
.
ID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: rpm override lookup failed for user=%d group=%d: %v"
,
user
.
ID
,
group
.
ID
,
err
,
)
}
else
{
override
=
dbOverride
}
}
if
override
!=
nil
{
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
if
*
override
>
0
{
count
,
incErr
:=
s
.
userRPMCache
.
IncrementUserGroupRPM
(
ctx
,
user
.
ID
,
group
.
ID
)
if
incErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: rpm increment (override) failed for user=%d group=%d: %v"
,
user
.
ID
,
group
.
ID
,
incErr
,
)
// fail-open
}
else
if
count
>
*
override
{
return
ErrGroupRPMExceeded
}
}
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
}
else
if
group
.
RPMLimit
>
0
{
// 无 override,检查 group.rpm_limit。
count
,
err
:=
s
.
userRPMCache
.
IncrementUserGroupRPM
(
ctx
,
user
.
ID
,
group
.
ID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: rpm increment (group) failed for user=%d group=%d: %v"
,
user
.
ID
,
group
.
ID
,
err
,
)
// fail-open
}
else
if
count
>
group
.
RPMLimit
{
return
ErrGroupRPMExceeded
}
}
}
// ── 第二层:用户级全局硬上限(始终生效) ──
if
user
.
RPMLimit
>
0
{
count
,
err
:=
s
.
userRPMCache
.
IncrementUserRPM
(
ctx
,
user
.
ID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: rpm increment (user) failed for user=%d: %v"
,
user
.
ID
,
err
,
)
return
nil
// fail-open
}
if
count
>
user
.
RPMLimit
{
return
ErrUserRPMExceeded
}
}
return
nil
}
...
...
backend/internal/service/billing_cache_service_rpm_test.go
0 → 100644
View file @
6b0cf466
//go:build unit
package
service
import
(
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
type
userRPMCacheStub
struct
{
userGroupCalls
int32
userCalls
int32
userGroupCounts
[]
int
// 依次返回的计数值
userGroupErr
error
userCounts
[]
int
userErr
error
}
func
(
s
*
userRPMCacheStub
)
IncrementUserGroupRPM
(
_
context
.
Context
,
_
,
_
int64
)
(
int
,
error
)
{
idx
:=
int
(
atomic
.
AddInt32
(
&
s
.
userGroupCalls
,
1
))
-
1
if
s
.
userGroupErr
!=
nil
{
return
0
,
s
.
userGroupErr
}
if
idx
<
len
(
s
.
userGroupCounts
)
{
return
s
.
userGroupCounts
[
idx
],
nil
}
return
1
,
nil
}
func
(
s
*
userRPMCacheStub
)
IncrementUserRPM
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
idx
:=
int
(
atomic
.
AddInt32
(
&
s
.
userCalls
,
1
))
-
1
if
s
.
userErr
!=
nil
{
return
0
,
s
.
userErr
}
if
idx
<
len
(
s
.
userCounts
)
{
return
s
.
userCounts
[
idx
],
nil
}
return
1
,
nil
}
func
(
s
*
userRPMCacheStub
)
GetUserGroupRPM
(
_
context
.
Context
,
_
,
_
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
userRPMCacheStub
)
GetUserRPM
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
return
0
,
nil
}
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
type
rpmOverrideRepoStub
struct
{
UserGroupRateRepository
override
*
int
err
error
calls
int32
}
func
(
s
*
rpmOverrideRepoStub
)
GetRPMOverrideByUserAndGroup
(
_
context
.
Context
,
_
,
_
int64
)
(
*
int
,
error
)
{
atomic
.
AddInt32
(
&
s
.
calls
,
1
)
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
return
s
.
override
,
nil
}
func
newBillingServiceForRPM
(
t
*
testing
.
T
,
cache
UserRPMCache
,
rateRepo
UserGroupRateRepository
)
*
BillingCacheService
{
t
.
Helper
()
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
// 我们只直接测 checkRPM。
svc
:=
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cache
,
rateRepo
,
&
config
.
Config
{})
t
.
Cleanup
(
svc
.
Stop
)
return
svc
}
func
TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup
(
t
*
testing
.
T
)
{
override
:=
2
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
cache
:=
&
userRPMCacheStub
{
userGroupCounts
:
[]
int
{
1
,
2
,
3
}}
repo
:=
&
rpmOverrideRepoStub
{
override
:
&
override
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
100
}
// 全局上限设高,不干扰 override 测试
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
100
}
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
ErrGroupRPMExceeded
)
require
.
EqualValues
(
t
,
3
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
),
"override 命中分支应走 user-group 计数"
)
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
require
.
EqualValues
(
t
,
2
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
),
"override 超限前 user 计数器应被调用"
)
require
.
EqualValues
(
t
,
3
,
atomic
.
LoadInt32
(
&
repo
.
calls
))
}
func
TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap
(
t
*
testing
.
T
)
{
override
:=
100
// override 很高
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
cache
:=
&
userRPMCacheStub
{
userCounts
:
[]
int
{
1
,
2
,
3
}}
repo
:=
&
rpmOverrideRepoStub
{
override
:
&
override
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
2
}
// 全局硬上限=2,应覆盖 override=100
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
100
}
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
ErrUserRPMExceeded
,
"user 全局硬上限应优先于 override"
)
}
func
TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies
(
t
*
testing
.
T
)
{
zero
:=
0
// user 计数: 依次返回 1..6
cache
:=
&
userRPMCacheStub
{
userCounts
:
[]
int
{
1
,
2
,
3
,
4
,
5
,
6
}}
repo
:=
&
rpmOverrideRepoStub
{
override
:
&
zero
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
5
}
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
100
}
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
for
i
:=
0
;
i
<
5
;
i
++
{
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
"request %d should pass"
,
i
+
1
)
}
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
ErrUserRPMExceeded
,
"override=0 跳过分组但 user 全局上限仍应生效"
)
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
),
"override=0 不应触发分组计数器"
)
require
.
EqualValues
(
t
,
6
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
),
"user 计数器应被调用"
)
}
func
TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited
(
t
*
testing
.
T
)
{
zero
:=
0
cache
:=
&
userRPMCacheStub
{}
repo
:=
&
rpmOverrideRepoStub
{
override
:
&
zero
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
0
}
// user 也不限
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
100
}
for
i
:=
0
;
i
<
50
;
i
++
{
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
}
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
),
"override=0 不触发分组计数"
)
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
),
"user.RPMLimit=0 也不触发用户计数"
)
}
func
TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup
(
t
*
testing
.
T
)
{
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
cache
:=
&
userRPMCacheStub
{
userGroupCounts
:
[]
int
{
5
,
6
}}
repo
:=
&
rpmOverrideRepoStub
{
override
:
nil
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
999
}
// 全局上限很高,group 先超
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
5
}
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
// ug=5, user=1, 都没超
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
ErrGroupRPMExceeded
)
// ug=6 > 5
require
.
EqualValues
(
t
,
2
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
))
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
require
.
EqualValues
(
t
,
1
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
),
"group 未超时 user 也应检查;group 超时直接返回"
)
}
func
TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{
userGroupCounts
:
[]
int
{
3
}}
repo
:=
&
rpmOverrideRepoStub
{
err
:
errors
.
New
(
"db down"
)}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
0
}
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
10
}
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
EqualValues
(
t
,
1
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
))
require
.
EqualValues
(
t
,
1
,
atomic
.
LoadInt32
(
&
repo
.
calls
))
}
func
TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{
userCounts
:
[]
int
{
1
,
2
,
3
}}
repo
:=
&
rpmOverrideRepoStub
{
override
:
nil
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
2
}
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
0
}
// 分组未设限
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
),
ErrUserRPMExceeded
)
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
),
"group 未设限时不应 INCR user-group 键"
)
require
.
EqualValues
(
t
,
3
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
))
}
func
TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{}
repo
:=
&
rpmOverrideRepoStub
{
override
:
nil
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
0
}
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
0
}
for
i
:=
0
;
i
<
10
;
i
++
{
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
}
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
))
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
))
}
func
TestBillingCacheService_CheckRPM_RedisErrorFailOpen
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{
userGroupErr
:
errors
.
New
(
"redis unavailable"
)}
repo
:=
&
rpmOverrideRepoStub
{
override
:
nil
}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
0
}
group
:=
&
Group
{
ID
:
10
,
RPMLimit
:
5
}
// Redis 故障时应 fail-open,不拒绝请求
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
group
))
require
.
EqualValues
(
t
,
1
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
))
}
func
TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{
userCounts
:
[]
int
{
1
,
2
,
3
}}
repo
:=
&
rpmOverrideRepoStub
{}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
user
:=
&
User
{
ID
:
1
,
RPMLimit
:
2
}
// 无 group(纯用户级限流场景),不应查询 rpm_override。
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
nil
))
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
nil
))
require
.
ErrorIs
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
user
,
nil
),
ErrUserRPMExceeded
)
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
repo
.
calls
),
"无 group 时不应查询 rpm_override"
)
require
.
EqualValues
(
t
,
3
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
))
}
func
TestBillingCacheService_CheckRPM_NilUserIsNoop
(
t
*
testing
.
T
)
{
cache
:=
&
userRPMCacheStub
{}
repo
:=
&
rpmOverrideRepoStub
{}
svc
:=
newBillingServiceForRPM
(
t
,
cache
,
repo
)
require
.
NoError
(
t
,
svc
.
checkRPM
(
context
.
Background
(),
nil
,
&
Group
{
ID
:
1
,
RPMLimit
:
10
}))
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userGroupCalls
))
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
cache
.
userCalls
))
require
.
EqualValues
(
t
,
0
,
atomic
.
LoadInt32
(
&
repo
.
calls
))
}
backend/internal/service/billing_cache_service_singleflight_test.go
View file @
6b0cf466
...
...
@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay
:
80
*
time
.
Millisecond
,
balance
:
12.34
,
}
svc
:=
NewBillingCacheService
(
cache
,
userRepo
,
nil
,
nil
,
&
config
.
Config
{})
svc
:=
NewBillingCacheService
(
cache
,
userRepo
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
t
.
Cleanup
(
svc
.
Stop
)
const
goroutines
=
16
...
...
backend/internal/service/billing_cache_service_test.go
View file @
6b0cf466
...
...
@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func
TestBillingCacheServiceQueueHighLoad
(
t
*
testing
.
T
)
{
cache
:=
&
billingCacheWorkerStub
{}
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
t
.
Cleanup
(
svc
.
Stop
)
start
:=
time
.
Now
()
...
...
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func
TestBillingCacheServiceEnqueueAfterStopReturnsFalse
(
t
*
testing
.
T
)
{
cache
:=
&
billingCacheWorkerStub
{}
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
.
Stop
()
enqueued
:=
svc
.
enqueueCacheWrite
(
cacheWriteTask
{
...
...
backend/internal/service/domain_constants.go
View file @
6b0cf466
...
...
@@ -170,9 +170,10 @@ const (
SettingKeyCustomEndpoints
=
"custom_endpoints"
// 自定义端点列表(JSON 数组)
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
SettingKeyDefaultBalance
=
"default_balance"
// 新用户默认余额
SettingKeyDefaultSubscriptions
=
"default_subscriptions"
// 新用户默认订阅列表(JSON)
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
SettingKeyDefaultBalance
=
"default_balance"
// 新用户默认余额
SettingKeyDefaultSubscriptions
=
"default_subscriptions"
// 新用户默认订阅列表(JSON)
SettingKeyDefaultUserRPMLimit
=
"default_user_rpm_limit"
// 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance
=
"auth_source_default_email_balance"
...
...
backend/internal/service/group.go
View file @
6b0cf466
...
...
@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel
string
MessagesDispatchModelConfig
OpenAIMessagesDispatchModelConfig
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
RPMLimit
int
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
backend/internal/service/setting_service.go
View file @
6b0cf466
...
...
@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
updates
[
SettingKeyDefaultUserRPMLimit
]
=
strconv
.
Itoa
(
settings
.
DefaultUserRPMLimit
)
defaultSubsJSON
,
err
:=
json
.
Marshal
(
settings
.
DefaultSubscriptions
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal default subscriptions: %w"
,
err
)
...
...
@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return
s
.
cfg
.
Default
.
UserBalance
}
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
func
(
s
*
SettingService
)
GetDefaultUserRPMLimit
(
ctx
context
.
Context
)
int
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyDefaultUserRPMLimit
)
if
err
!=
nil
||
value
==
""
{
return
0
}
if
v
,
err
:=
strconv
.
Atoi
(
value
);
err
==
nil
&&
v
>=
0
{
return
v
}
return
0
}
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func
(
s
*
SettingService
)
GetDefaultSubscriptions
(
ctx
context
.
Context
)
[]
DefaultSubscriptionSetting
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyDefaultSubscriptions
)
...
...
@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath
:
""
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeyDefaultUserRPMLimit
:
"0"
,
SettingKeyDefaultSubscriptions
:
"[]"
,
SettingKeyAuthSourceDefaultEmailBalance
:
"0"
,
SettingKeyAuthSourceDefaultEmailConcurrency
:
"5"
,
...
...
@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
DefaultConcurrency
=
s
.
cfg
.
Default
.
UserConcurrency
}
if
rpm
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyDefaultUserRPMLimit
]);
err
==
nil
&&
rpm
>=
0
{
result
.
DefaultUserRPMLimit
=
rpm
}
// 解析浮点数类型
if
balance
,
err
:=
strconv
.
ParseFloat
(
settings
[
SettingKeyDefaultBalance
],
64
);
err
==
nil
{
result
.
DefaultBalance
=
balance
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment