Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6b0cf466
Unverified
Commit
6b0cf466
authored
Apr 23, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 23, 2026
Browse files
Merge pull request #1815 from james-6-23/feat_rpm
feat(rpm): RPM 限流模块优化
parents
ef967d8f
dc5d42ad
Changes
79
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/group_handler.go
View file @
6b0cf466
...
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
...
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet
bool
`json:"require_privacy_set"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
DefaultMappedModel
string
`json:"default_mapped_model"`
DefaultMappedModel
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
MessagesDispatchModelConfig
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制)
RPMLimit
int
`json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
}
...
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
...
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet
*
bool
`json:"require_privacy_set"`
RequirePrivacySet
*
bool
`json:"require_privacy_set"`
DefaultMappedModel
*
string
`json:"default_mapped_model"`
DefaultMappedModel
*
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
*
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
MessagesDispatchModelConfig
*
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
RPMLimit
*
int
`json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
}
...
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
...
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet
:
req
.
RequirePrivacySet
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
RPMLimit
:
req
.
RPMLimit
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
...
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet
:
req
.
RequirePrivacySet
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
RPMLimit
:
req
.
RPMLimit
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
...
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Rate multipliers updated successfully"
})
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Rate multipliers updated successfully"
})
}
}
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
type
BatchSetGroupRPMOverridesRequest
struct
{
Entries
[]
service
.
GroupRPMOverrideInput
`json:"entries" binding:"required"`
}
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
// PUT /api/v1/admin/groups/:id/rpm-overrides
func
(
h
*
GroupHandler
)
BatchSetGroupRPMOverrides
(
c
*
gin
.
Context
)
{
groupID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid group ID"
)
return
}
var
req
BatchSetGroupRPMOverridesRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
err
:=
h
.
adminService
.
BatchSetGroupRPMOverrides
(
c
.
Request
.
Context
(),
groupID
,
req
.
Entries
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"RPM overrides updated successfully"
})
}
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
// DELETE /api/v1/admin/groups/:id/rpm-overrides
func
(
h
*
GroupHandler
)
ClearGroupRPMOverrides
(
c
*
gin
.
Context
)
{
groupID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid group ID"
)
return
}
if
err
:=
h
.
adminService
.
ClearGroupRPMOverrides
(
c
.
Request
.
Context
(),
groupID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"RPM overrides cleared successfully"
})
}
// UpdateSortOrderRequest represents the request to update group sort orders
// UpdateSortOrderRequest represents the request to update group sort orders
type
UpdateSortOrderRequest
struct
{
type
UpdateSortOrderRequest
struct
{
Updates
[]
struct
{
Updates
[]
struct
{
...
...
backend/internal/handler/admin/setting_handler.go
View file @
6b0cf466
...
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
...
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
settings
.
CustomEndpoints
),
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
settings
.
CustomEndpoints
),
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultBalance
:
settings
.
DefaultBalance
,
DefaultBalance
:
settings
.
DefaultBalance
,
DefaultUserRPMLimit
:
settings
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
defaultSubscriptions
,
DefaultSubscriptions
:
defaultSubscriptions
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
FallbackModelAnthropic
:
settings
.
FallbackModelAnthropic
,
FallbackModelAnthropic
:
settings
.
FallbackModelAnthropic
,
...
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
...
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
// 默认配置
// 默认配置
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultBalance
float64
`json:"default_balance"`
DefaultUserRPMLimit
int
`json:"default_user_rpm_limit"`
DefaultSubscriptions
[]
dto
.
DefaultSubscriptionSetting
`json:"default_subscriptions"`
DefaultSubscriptions
[]
dto
.
DefaultSubscriptionSetting
`json:"default_subscriptions"`
AuthSourceDefaultEmailBalance
*
float64
`json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailBalance
*
float64
`json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency
*
int
`json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailConcurrency
*
int
`json:"auth_source_default_email_concurrency"`
...
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
...
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints
:
customEndpointsJSON
,
CustomEndpoints
:
customEndpointsJSON
,
DefaultConcurrency
:
req
.
DefaultConcurrency
,
DefaultConcurrency
:
req
.
DefaultConcurrency
,
DefaultBalance
:
req
.
DefaultBalance
,
DefaultBalance
:
req
.
DefaultBalance
,
DefaultUserRPMLimit
:
req
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
defaultSubscriptions
,
DefaultSubscriptions
:
defaultSubscriptions
,
EnableModelFallback
:
req
.
EnableModelFallback
,
EnableModelFallback
:
req
.
EnableModelFallback
,
FallbackModelAnthropic
:
req
.
FallbackModelAnthropic
,
FallbackModelAnthropic
:
req
.
FallbackModelAnthropic
,
...
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
...
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
updatedSettings
.
CustomEndpoints
),
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
updatedSettings
.
CustomEndpoints
),
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
DefaultUserRPMLimit
:
updatedSettings
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
updatedDefaultSubscriptions
,
DefaultSubscriptions
:
updatedDefaultSubscriptions
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
FallbackModelAnthropic
:
updatedSettings
.
FallbackModelAnthropic
,
FallbackModelAnthropic
:
updatedSettings
.
FallbackModelAnthropic
,
...
...
backend/internal/handler/admin/user_handler.go
View file @
6b0cf466
...
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
...
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes
string
`json:"notes"`
Notes
string
`json:"notes"`
Balance
float64
`json:"balance"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
Concurrency
int
`json:"concurrency"`
RPMLimit
int
`json:"rpm_limit"`
AllowedGroups
[]
int64
`json:"allowed_groups"`
AllowedGroups
[]
int64
`json:"allowed_groups"`
}
}
...
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
...
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes
*
string
`json:"notes"`
Notes
*
string
`json:"notes"`
Balance
*
float64
`json:"balance"`
Balance
*
float64
`json:"balance"`
Concurrency
*
int
`json:"concurrency"`
Concurrency
*
int
`json:"concurrency"`
RPMLimit
*
int
`json:"rpm_limit"`
Status
string
`json:"status" binding:"omitempty,oneof=active disabled"`
Status
string
`json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups
*
[]
int64
`json:"allowed_groups"`
AllowedGroups
*
[]
int64
`json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// GroupRates 用户专属分组倍率配置
...
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
...
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes
:
req
.
Notes
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
Concurrency
:
req
.
Concurrency
,
RPMLimit
:
req
.
RPMLimit
,
AllowedGroups
:
req
.
AllowedGroups
,
AllowedGroups
:
req
.
AllowedGroups
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
...
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes
:
req
.
Notes
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
Concurrency
:
req
.
Concurrency
,
RPMLimit
:
req
.
RPMLimit
,
Status
:
req
.
Status
,
Status
:
req
.
Status
,
AllowedGroups
:
req
.
AllowedGroups
,
AllowedGroups
:
req
.
AllowedGroups
,
GroupRates
:
req
.
GroupRates
,
GroupRates
:
req
.
GroupRates
,
...
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
...
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys"
:
result
.
MigratedKeys
,
"migrated_keys"
:
result
.
MigratedKeys
,
})
})
}
}
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
// GET /api/v1/admin/users/:id/rpm-status
func
(
h
*
UserHandler
)
GetUserRPMStatus
(
c
*
gin
.
Context
)
{
userID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid user ID"
)
return
}
status
,
err
:=
h
.
adminService
.
GetUserRPMStatus
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
status
)
}
backend/internal/handler/dto/mappers.go
View file @
6b0cf466
...
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
...
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
NotifyEmailEntriesFromService
(
u
.
BalanceNotifyExtraEmails
),
BalanceNotifyExtraEmails
:
NotifyEmailEntriesFromService
(
u
.
BalanceNotifyExtraEmails
),
TotalRecharged
:
u
.
TotalRecharged
,
TotalRecharged
:
u
.
TotalRecharged
,
RPMLimit
:
u
.
RPMLimit
,
}
}
}
}
...
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
...
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
AllowMessagesDispatch
:
g
.
AllowMessagesDispatch
,
AllowMessagesDispatch
:
g
.
AllowMessagesDispatch
,
RequireOAuthOnly
:
g
.
RequireOAuthOnly
,
RequireOAuthOnly
:
g
.
RequireOAuthOnly
,
RequirePrivacySet
:
g
.
RequirePrivacySet
,
RequirePrivacySet
:
g
.
RequirePrivacySet
,
RPMLimit
:
g
.
RPMLimit
,
CreatedAt
:
g
.
CreatedAt
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
}
...
...
backend/internal/handler/dto/settings.go
View file @
6b0cf466
...
@@ -108,6 +108,7 @@ type SystemSettings struct {
...
@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultBalance
float64
`json:"default_balance"`
DefaultUserRPMLimit
int
`json:"default_user_rpm_limit"`
DefaultSubscriptions
[]
DefaultSubscriptionSetting
`json:"default_subscriptions"`
DefaultSubscriptions
[]
DefaultSubscriptionSetting
`json:"default_subscriptions"`
// Model fallback configuration
// Model fallback configuration
...
...
backend/internal/handler/dto/types.go
View file @
6b0cf466
...
@@ -26,6 +26,9 @@ type User struct {
...
@@ -26,6 +26,9 @@ type User struct {
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails"`
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails"`
TotalRecharged
float64
`json:"total_recharged"`
TotalRecharged
float64
`json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
RPMLimit
int
`json:"rpm_limit"`
APIKeys
[]
APIKey
`json:"api_keys,omitempty"`
APIKeys
[]
APIKey
`json:"api_keys,omitempty"`
Subscriptions
[]
UserSubscription
`json:"subscriptions,omitempty"`
Subscriptions
[]
UserSubscription
`json:"subscriptions,omitempty"`
}
}
...
@@ -108,6 +111,9 @@ type Group struct {
...
@@ -108,6 +111,9 @@ type Group struct {
RequireOAuthOnly
bool
`json:"require_oauth_only"`
RequireOAuthOnly
bool
`json:"require_oauth_only"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
RPMLimit
int
`json:"rpm_limit"`
CreatedAt
time
.
Time
`json:"created_at"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
}
...
...
backend/internal/handler/gateway_handler.go
View file @
6b0cf466
...
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
// 2. 【新增】Wait后二次检查余额/订阅
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"gateway.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
fallbackAPIKey
:=
cloneAPIKeyWithGroup
(
apiKey
,
fallbackGroup
)
fallbackAPIKey
:=
cloneAPIKeyWithGroup
(
apiKey
,
fallbackGroup
)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
fallbackAPIKey
.
User
,
fallbackAPIKey
,
fallbackGroup
,
nil
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
fallbackAPIKey
.
User
,
fallbackAPIKey
,
fallbackGroup
,
nil
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
// 【注意】不计算并发,但需要校验订阅/余额
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
errorResponse
(
c
,
status
,
code
,
message
)
h
.
errorResponse
(
c
,
status
,
code
,
message
)
return
return
}
}
...
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
...
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c
.
JSON
(
http
.
StatusOK
,
response
)
c
.
JSON
(
http
.
StatusOK
,
response
)
}
}
func
billingErrorDetails
(
err
error
)
(
status
int
,
code
,
message
string
)
{
func
billingErrorDetails
(
err
error
)
(
status
int
,
code
,
message
string
,
retryAfter
int
)
{
if
errors
.
Is
(
err
,
service
.
ErrBillingServiceUnavailable
)
{
if
errors
.
Is
(
err
,
service
.
ErrBillingServiceUnavailable
)
{
msg
:=
pkgerrors
.
Message
(
err
)
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
if
msg
==
""
{
msg
=
"Billing service temporarily unavailable. Please retry later."
msg
=
"Billing service temporarily unavailable. Please retry later."
}
}
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
,
0
}
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit5hExceeded
)
{
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit5hExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit1dExceeded
)
{
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit1dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit7dExceeded
)
{
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit7dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
if
errors
.
Is
(
err
,
service
.
ErrGroupRPMExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrUserRPMExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
retrySeconds
:=
60
-
int
(
time
.
Now
()
.
Unix
()
%
60
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
retrySeconds
}
}
msg
:=
pkgerrors
.
Message
(
err
)
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
if
msg
==
""
{
...
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
...
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
)
.
Warn
(
"gateway.billing_error_missing_message"
)
)
.
Warn
(
"gateway.billing_error_missing_message"
)
msg
=
"Billing error"
msg
=
"Billing error"
}
}
return
http
.
StatusForbidden
,
"billing_error"
,
msg
return
http
.
StatusForbidden
,
"billing_error"
,
msg
,
0
}
}
func
(
h
*
GatewayHandler
)
metadataBridgeEnabled
()
bool
{
func
(
h
*
GatewayHandler
)
metadataBridgeEnabled
()
bool
{
...
...
backend/internal/handler/gateway_handler_billing_error_test.go
0 → 100644
View file @
6b0cf466
package
handler
import
(
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrGroupRPMExceeded
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
require
.
Greater
(
t
,
retryAfter
,
0
,
"RPM exceeded should return positive Retry-After"
)
require
.
LessOrEqual
(
t
,
retryAfter
,
60
)
}
func
TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrUserRPMExceeded
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
require
.
Greater
(
t
,
retryAfter
,
0
,
"RPM exceeded should return positive Retry-After"
)
require
.
LessOrEqual
(
t
,
retryAfter
,
60
)
}
func
TestBillingErrorDetails_APIKeyRateLimitStillMaps
(
t
*
testing
.
T
)
{
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
for
_
,
err
:=
range
[]
error
{
service
.
ErrAPIKeyRateLimit5hExceeded
,
service
.
ErrAPIKeyRateLimit1dExceeded
,
service
.
ErrAPIKeyRateLimit7dExceeded
,
}
{
status
,
code
,
_
,
_
:=
billingErrorDetails
(
err
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
,
"status for %v"
,
err
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
}
}
func
TestBillingErrorDetails_BillingServiceUnavailableMapsTo503
(
t
*
testing
.
T
)
{
status
,
code
,
_
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrBillingServiceUnavailable
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
status
)
require
.
Equal
(
t
,
"billing_service_error"
,
code
)
require
.
Equal
(
t
,
0
,
retryAfter
,
"non-RPM errors should not set Retry-After"
)
}
func
TestBillingErrorDetails_UnknownErrorFallsBackTo403
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
_
:=
billingErrorDetails
(
service
.
ErrInsufficientBalance
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
status
)
require
.
Equal
(
t
,
"billing_error"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
}
backend/internal/handler/gateway_handler_chat_completions.go
View file @
6b0cf466
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"net/http"
"net/http"
"strconv"
"time"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
...
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing
// 2. Re-check billing
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.cc.billing_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"gateway.cc.billing_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
chatCompletionsErrorResponse
(
c
,
status
,
code
,
message
)
h
.
chatCompletionsErrorResponse
(
c
,
status
,
code
,
message
)
return
return
}
}
...
...
backend/internal/handler/gateway_handler_responses.go
View file @
6b0cf466
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"net/http"
"net/http"
"strconv"
"time"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
...
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing
// 2. Re-check billing
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.responses.billing_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"gateway.responses.billing_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
responsesErrorResponse
(
c
,
status
,
code
,
message
)
h
.
responsesErrorResponse
(
c
,
status
,
code
,
message
)
return
return
}
}
...
...
backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
View file @
6b0cf466
...
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
...
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cfg
)
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
concurrencySvc
:=
service
.
NewConcurrencyService
(
&
fakeConcurrencyCache
{})
concurrencySvc
:=
service
.
NewConcurrencyService
(
&
fakeConcurrencyCache
{})
concurrencyHelper
:=
NewConcurrencyHelper
(
concurrencySvc
,
SSEPingFormatClaude
,
0
)
concurrencyHelper
:=
NewConcurrencyHelper
(
concurrencySvc
,
SSEPingFormatClaude
,
0
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
6b0cf466
...
@@ -9,6 +9,7 @@ import (
...
@@ -9,6 +9,7 @@ import (
"errors"
"errors"
"net/http"
"net/http"
"regexp"
"regexp"
"strconv"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/domain"
...
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
// 2) billing eligibility check (after wait)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gemini.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"gemini.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
_
,
message
:=
billingErrorDetails
(
err
)
status
,
_
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
googleError
(
c
,
status
,
message
)
googleError
(
c
,
status
,
message
)
return
return
}
}
...
...
backend/internal/handler/openai_chat_completions.go
View file @
6b0cf466
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"net/http"
"net/http"
"strconv"
"time"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
...
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai_chat_completions.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"openai_chat_completions.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
6b0cf466
...
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
// 2. Re-check billing eligibility after wait
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"openai.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai_messages.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"openai_messages.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
anthropicStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
anthropicStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
...
backend/internal/handler/openai_images.go
View file @
6b0cf466
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"errors"
"errors"
"net/http"
"net/http"
"strconv"
"strings"
"strings"
"time"
"time"
...
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
...
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai.images.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
reqLog
.
Info
(
"openai.images.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
return
}
}
...
...
backend/internal/repository/api_key_repo.go
View file @
6b0cf466
...
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
...
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user
.
FieldSignupSource
,
user
.
FieldSignupSource
,
user
.
FieldLastLoginAt
,
user
.
FieldLastLoginAt
,
user
.
FieldLastActiveAt
,
user
.
FieldLastActiveAt
,
user
.
FieldRpmLimit
,
)
)
})
.
})
.
WithGroup
(
func
(
q
*
dbent
.
GroupQuery
)
{
WithGroup
(
func
(
q
*
dbent
.
GroupQuery
)
{
...
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
...
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldAllowMessagesDispatch
,
group
.
FieldAllowMessagesDispatch
,
group
.
FieldDefaultMappedModel
,
group
.
FieldDefaultMappedModel
,
group
.
FieldMessagesDispatchModelConfig
,
group
.
FieldMessagesDispatchModelConfig
,
group
.
FieldRpmLimit
,
)
)
})
.
})
.
Only
(
ctx
)
Only
(
ctx
)
...
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
...
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType
:
u
.
BalanceNotifyThresholdType
,
BalanceNotifyThresholdType
:
u
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
TotalRecharged
:
u
.
TotalRecharged
,
TotalRecharged
:
u
.
TotalRecharged
,
RPMLimit
:
u
.
RpmLimit
,
CreatedAt
:
u
.
CreatedAt
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
}
...
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
...
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet
:
g
.
RequirePrivacySet
,
RequirePrivacySet
:
g
.
RequirePrivacySet
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
g
.
MessagesDispatchModelConfig
,
MessagesDispatchModelConfig
:
g
.
MessagesDispatchModelConfig
,
RPMLimit
:
g
.
RpmLimit
,
CreatedAt
:
g
.
CreatedAt
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
}
...
...
backend/internal/repository/group_repo.go
View file @
6b0cf466
...
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
...
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
.
SetRpmLimit
(
groupIn
.
RPMLimit
)
// 设置模型路由配置
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
if
groupIn
.
ModelRouting
!=
nil
{
...
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
...
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
.
SetRpmLimit
(
groupIn
.
RPMLimit
)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if
groupIn
.
DailyLimitUSD
!=
nil
{
if
groupIn
.
DailyLimitUSD
!=
nil
{
...
...
backend/internal/repository/user_group_rate_repo.go
View file @
6b0cf466
...
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
...
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql
sqlExecutor
sql
sqlExecutor
}
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
// NewUserGroupRateRepository 创建用户专属分组倍率
/RPM
仓储
func
NewUserGroupRateRepository
(
sqlDB
*
sql
.
DB
)
service
.
UserGroupRateRepository
{
func
NewUserGroupRateRepository
(
sqlDB
*
sql
.
DB
)
service
.
UserGroupRateRepository
{
return
&
userGroupRateRepository
{
sql
:
sqlDB
}
return
&
userGroupRateRepository
{
sql
:
sqlDB
}
}
}
// GetByUserID 获取用户
的
所有专属分组
倍率
// GetByUserID 获取用户所有专属分组
rate_multiplier(仅返回非 NULL 的条目)
func
(
r
*
userGroupRateRepository
)
GetByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
map
[
int64
]
float64
,
error
)
{
func
(
r
*
userGroupRateRepository
)
GetByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
map
[
int64
]
float64
,
error
)
{
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1
AND rate_multiplier IS NOT NULL
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
userID
)
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
userID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
...
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return
result
,
nil
return
result
,
nil
}
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
// 返回结构:map[userID]map[groupID]rate
func
(
r
*
userGroupRateRepository
)
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
{
func
(
r
*
userGroupRateRepository
)
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
{
result
:=
make
(
map
[
int64
]
map
[
int64
]
float64
,
len
(
userIDs
))
result
:=
make
(
map
[
int64
]
map
[
int64
]
float64
,
len
(
userIDs
))
if
len
(
userIDs
)
==
0
{
if
len
(
userIDs
)
==
0
{
...
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
...
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT user_id, group_id, rate_multiplier
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
WHERE user_id = ANY($1)
AND rate_multiplier IS NOT NULL
`
,
pq
.
Array
(
uniqueIDs
))
`
,
pq
.
Array
(
uniqueIDs
))
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
...
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return
result
,
nil
return
result
,
nil
}
}
// GetByGroupID 获取指定分组下所有用户的专属
倍率
// GetByGroupID 获取指定分组下所有用户的专属
配置(rate 与 rpm_override 任一非 NULL 即返回)
func
(
r
*
userGroupRateRepository
)
GetByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
UserGroupRateEntry
,
error
)
{
func
(
r
*
userGroupRateRepository
)
GetByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
UserGroupRateEntry
,
error
)
{
query
:=
`
query
:=
`
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
, ugr.rpm_override
FROM user_group_rate_multipliers ugr
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
WHERE ugr.group_id = $1
...
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
...
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var
result
[]
service
.
UserGroupRateEntry
var
result
[]
service
.
UserGroupRateEntry
for
rows
.
Next
()
{
for
rows
.
Next
()
{
var
entry
service
.
UserGroupRateEntry
var
entry
service
.
UserGroupRateEntry
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
entry
.
RateMultiplier
);
err
!=
nil
{
var
rate
sql
.
NullFloat64
var
rpm
sql
.
NullInt32
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
rate
,
&
rpm
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
if
rate
.
Valid
{
v
:=
rate
.
Float64
entry
.
RateMultiplier
=
&
v
}
if
rpm
.
Valid
{
v
:=
int
(
rpm
.
Int32
)
entry
.
RPMOverride
=
&
v
}
result
=
append
(
result
,
entry
)
result
=
append
(
result
,
entry
)
}
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
if
err
:=
rows
.
Err
();
err
!=
nil
{
...
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
...
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return
result
,
nil
return
result
,
nil
}
}
// GetByUserAndGroup 获取用户在特定分组的专属
倍率
// GetByUserAndGroup 获取用户在特定分组的专属
rate_multiplier(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
func
(
r
*
userGroupRateRepository
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
query
:=
`SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
query
:=
`SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rate
f
loat64
var
rate
sql
.
NullF
loat64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rate
)
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rate
)
if
err
==
sql
.
ErrNoRows
{
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
return
nil
,
nil
...
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
...
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
return
&
rate
,
nil
if
!
rate
.
Valid
{
return
nil
,
nil
}
v
:=
rate
.
Float64
return
&
v
,
nil
}
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetRPMOverrideByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
int
,
error
)
{
query
:=
`SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rpm
sql
.
NullInt32
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rpm
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
}
if
err
!=
nil
{
return
nil
,
err
}
if
!
rpm
.
Valid
{
return
nil
,
nil
}
v
:=
int
(
rpm
.
Int32
)
return
&
v
,
nil
}
}
// SyncUserGroupRates 同步用户的分组专属倍率
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func
(
r
*
userGroupRateRepository
)
SyncUserGroupRates
(
ctx
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
func
(
r
*
userGroupRateRepository
)
SyncUserGroupRates
(
ctx
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
if
len
(
rates
)
==
0
{
if
len
(
rates
)
==
0
{
// 如果传入空 map,删除该用户的所有专属倍率
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`
,
userID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`
,
userID
)
return
err
return
err
}
}
// 分离需要删除和需要 upsert 的记录
var
clearGroupIDs
[]
int64
var
toDelete
[]
int64
upsertGroupIDs
:=
make
([]
int64
,
0
,
len
(
rates
))
upsertGroupIDs
:=
make
([]
int64
,
0
,
len
(
rates
))
upsertRates
:=
make
([]
float64
,
0
,
len
(
rates
))
upsertRates
:=
make
([]
float64
,
0
,
len
(
rates
))
for
groupID
,
rate
:=
range
rates
{
for
groupID
,
rate
:=
range
rates
{
if
rate
==
nil
{
if
rate
==
nil
{
toDelete
=
append
(
toDelete
,
groupID
)
clearGroupIDs
=
append
(
clearGroupIDs
,
groupID
)
}
else
{
}
else
{
upsertGroupIDs
=
append
(
upsertGroupIDs
,
groupID
)
upsertGroupIDs
=
append
(
upsertGroupIDs
,
groupID
)
upsertRates
=
append
(
upsertRates
,
*
rate
)
upsertRates
=
append
(
upsertRates
,
*
rate
)
}
}
}
}
// 删除指定的记录
if
len
(
clearGroupIDs
)
>
0
{
if
len
(
toDelete
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`
,
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
}
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)
AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
userID
,
pq
.
Array
(
toDelete
));
err
!=
nil
{
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
return
err
}
}
}
}
// Upsert 记录
now
:=
time
.
Now
()
if
len
(
upsertGroupIDs
)
>
0
{
if
len
(
upsertGroupIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
SELECT
...
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
...
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return
nil
return
nil
}
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
// 语义:
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行:upsert rate_multiplier。
func
(
r
*
userGroupRateRepository
)
SyncGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRateMultiplierInput
)
error
{
func
(
r
*
userGroupRateRepository
)
SyncGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRateMultiplierInput
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
);
err
!=
nil
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
return
err
}
}
if
len
(
entries
)
==
0
{
if
len
(
entries
)
==
0
{
return
nil
return
nil
}
}
userIDs
:=
make
([]
int64
,
len
(
entries
))
userIDs
:=
make
([]
int64
,
len
(
entries
))
rates
:=
make
([]
float64
,
len
(
entries
))
rates
:=
make
([]
float64
,
len
(
entries
))
for
i
,
e
:=
range
entries
{
for
i
,
e
:=
range
entries
{
...
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
...
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return
err
return
err
}
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
// 语义:
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func
(
r
*
userGroupRateRepository
)
SyncGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRPMOverrideInput
)
error
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
var
clearUserIDs
[]
int64
upsertUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
upsertValues
:=
make
([]
int32
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
if
e
.
RPMOverride
==
nil
{
clearUserIDs
=
append
(
clearUserIDs
,
e
.
UserID
)
}
else
{
upsertUserIDs
=
append
(
upsertUserIDs
,
e
.
UserID
)
upsertValues
=
append
(
upsertValues
,
int32
(
*
e
.
RPMOverride
))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 显式 clear 的行。
if
len
(
clearUserIDs
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`
,
groupID
,
pq
.
Array
(
clearUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
}
if
len
(
upsertUserIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`
,
groupID
,
now
,
pq
.
Array
(
upsertUserIDs
),
pq
.
Array
(
upsertValues
))
if
err
!=
nil
{
return
err
}
}
return
nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func
(
r
*
userGroupRateRepository
)
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
)
return
err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func
(
r
*
userGroupRateRepository
)
DeleteByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
error
{
func
(
r
*
userGroupRateRepository
)
DeleteByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
)
return
err
return
err
}
}
// DeleteByUserID 删除指定用户的所有专属
倍率
// DeleteByUserID 删除指定用户的所有专属
条目
func
(
r
*
userGroupRateRepository
)
DeleteByUserID
(
ctx
context
.
Context
,
userID
int64
)
error
{
func
(
r
*
userGroupRateRepository
)
DeleteByUserID
(
ctx
context
.
Context
,
userID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
return
err
return
err
...
...
backend/internal/repository/user_repo.go
View file @
6b0cf466
...
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
...
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
.
Save
(
txCtx
)
Save
(
txCtx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
...
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
...
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType
(
userIn
.
BalanceNotifyThresholdType
)
.
SetBalanceNotifyThresholdType
(
userIn
.
BalanceNotifyThresholdType
)
.
SetNillableBalanceNotifyThreshold
(
userIn
.
BalanceNotifyThreshold
)
.
SetNillableBalanceNotifyThreshold
(
userIn
.
BalanceNotifyThreshold
)
.
SetBalanceNotifyExtraEmails
(
marshalExtraEmails
(
userIn
.
BalanceNotifyExtraEmails
))
.
SetBalanceNotifyExtraEmails
(
marshalExtraEmails
(
userIn
.
BalanceNotifyExtraEmails
))
.
SetTotalRecharged
(
userIn
.
TotalRecharged
)
SetTotalRecharged
(
userIn
.
TotalRecharged
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
if
userIn
.
SignupSource
!=
""
{
if
userIn
.
SignupSource
!=
""
{
updateOp
=
updateOp
.
SetSignupSource
(
userIn
.
SignupSource
)
updateOp
=
updateOp
.
SetSignupSource
(
userIn
.
SignupSource
)
}
}
...
...
backend/internal/repository/user_rpm_cache.go
0 → 100644
View file @
6b0cf466
package
repository
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
const
(
userGroupRPMKeyPrefix
=
"rpm:ug:"
userRPMKeyPrefix
=
"rpm:u:"
userRPMKeyTTL
=
120
*
time
.
Second
)
type
userRPMCacheImpl
struct
{
rdb
*
redis
.
Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func
NewUserRPMCache
(
rdb
*
redis
.
Client
)
service
.
UserRPMCache
{
return
&
userRPMCacheImpl
{
rdb
:
rdb
}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func
(
c
*
userRPMCacheImpl
)
minuteTS
(
ctx
context
.
Context
)
(
int64
,
error
)
{
t
,
err
:=
c
.
rdb
.
Time
(
ctx
)
.
Result
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"redis TIME: %w"
,
err
)
}
return
t
.
Unix
()
/
60
,
nil
}
// atomicIncr 原子 INCR+EXPIRE。
func
(
c
*
userRPMCacheImpl
)
atomicIncr
(
ctx
context
.
Context
,
key
string
)
(
int
,
error
)
{
pipe
:=
c
.
rdb
.
TxPipeline
()
incr
:=
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
userRPMKeyTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm increment: %w"
,
err
)
}
return
int
(
incr
.
Val
()),
nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// IncrementUserRPM 递增用户分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user group rpm get: %w"
,
err
)
}
return
val
,
nil
}
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm get: %w"
,
err
)
}
return
val
,
nil
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment