Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6b0cf466
Unverified
Commit
6b0cf466
authored
Apr 23, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 23, 2026
Browse files
Merge pull request #1815 from james-6-23/feat_rpm
feat(rpm): RPM 限流模块优化
parents
ef967d8f
dc5d42ad
Changes
79
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/group_handler.go
View file @
6b0cf466
...
...
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet
bool
`json:"require_privacy_set"`
DefaultMappedModel
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制)
RPMLimit
int
`json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet
*
bool
`json:"require_privacy_set"`
DefaultMappedModel
*
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
*
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
RPMLimit
*
int
`json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
RPMLimit
:
req
.
RPMLimit
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
RPMLimit
:
req
.
RPMLimit
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Rate multipliers updated successfully"
})
}
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
type
BatchSetGroupRPMOverridesRequest
struct
{
Entries
[]
service
.
GroupRPMOverrideInput
`json:"entries" binding:"required"`
}
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
// PUT /api/v1/admin/groups/:id/rpm-overrides
func
(
h
*
GroupHandler
)
BatchSetGroupRPMOverrides
(
c
*
gin
.
Context
)
{
groupID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid group ID"
)
return
}
var
req
BatchSetGroupRPMOverridesRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
err
:=
h
.
adminService
.
BatchSetGroupRPMOverrides
(
c
.
Request
.
Context
(),
groupID
,
req
.
Entries
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"RPM overrides updated successfully"
})
}
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
// DELETE /api/v1/admin/groups/:id/rpm-overrides
func
(
h
*
GroupHandler
)
ClearGroupRPMOverrides
(
c
*
gin
.
Context
)
{
groupID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid group ID"
)
return
}
if
err
:=
h
.
adminService
.
ClearGroupRPMOverrides
(
c
.
Request
.
Context
(),
groupID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"RPM overrides cleared successfully"
})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type
UpdateSortOrderRequest
struct
{
Updates
[]
struct
{
...
...
backend/internal/handler/admin/setting_handler.go
View file @
6b0cf466
...
...
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
settings
.
CustomEndpoints
),
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultBalance
:
settings
.
DefaultBalance
,
DefaultUserRPMLimit
:
settings
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
defaultSubscriptions
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
FallbackModelAnthropic
:
settings
.
FallbackModelAnthropic
,
...
...
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
// 默认配置
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultUserRPMLimit
int
`json:"default_user_rpm_limit"`
DefaultSubscriptions
[]
dto
.
DefaultSubscriptionSetting
`json:"default_subscriptions"`
AuthSourceDefaultEmailBalance
*
float64
`json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency
*
int
`json:"auth_source_default_email_concurrency"`
...
...
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints
:
customEndpointsJSON
,
DefaultConcurrency
:
req
.
DefaultConcurrency
,
DefaultBalance
:
req
.
DefaultBalance
,
DefaultUserRPMLimit
:
req
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
defaultSubscriptions
,
EnableModelFallback
:
req
.
EnableModelFallback
,
FallbackModelAnthropic
:
req
.
FallbackModelAnthropic
,
...
...
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
updatedSettings
.
CustomEndpoints
),
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
DefaultUserRPMLimit
:
updatedSettings
.
DefaultUserRPMLimit
,
DefaultSubscriptions
:
updatedDefaultSubscriptions
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
FallbackModelAnthropic
:
updatedSettings
.
FallbackModelAnthropic
,
...
...
backend/internal/handler/admin/user_handler.go
View file @
6b0cf466
...
...
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes
string
`json:"notes"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
RPMLimit
int
`json:"rpm_limit"`
AllowedGroups
[]
int64
`json:"allowed_groups"`
}
...
...
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes
*
string
`json:"notes"`
Balance
*
float64
`json:"balance"`
Concurrency
*
int
`json:"concurrency"`
RPMLimit
*
int
`json:"rpm_limit"`
Status
string
`json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups
*
[]
int64
`json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
...
...
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
RPMLimit
:
req
.
RPMLimit
,
AllowedGroups
:
req
.
AllowedGroups
,
})
if
err
!=
nil
{
...
...
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
RPMLimit
:
req
.
RPMLimit
,
Status
:
req
.
Status
,
AllowedGroups
:
req
.
AllowedGroups
,
GroupRates
:
req
.
GroupRates
,
...
...
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys"
:
result
.
MigratedKeys
,
})
}
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
// GET /api/v1/admin/users/:id/rpm-status
func
(
h
*
UserHandler
)
GetUserRPMStatus
(
c
*
gin
.
Context
)
{
userID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid user ID"
)
return
}
status
,
err
:=
h
.
adminService
.
GetUserRPMStatus
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
status
)
}
backend/internal/handler/dto/mappers.go
View file @
6b0cf466
...
...
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
NotifyEmailEntriesFromService
(
u
.
BalanceNotifyExtraEmails
),
TotalRecharged
:
u
.
TotalRecharged
,
RPMLimit
:
u
.
RPMLimit
,
}
}
...
...
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
AllowMessagesDispatch
:
g
.
AllowMessagesDispatch
,
RequireOAuthOnly
:
g
.
RequireOAuthOnly
,
RequirePrivacySet
:
g
.
RequirePrivacySet
,
RPMLimit
:
g
.
RPMLimit
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/handler/dto/settings.go
View file @
6b0cf466
...
...
@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultUserRPMLimit
int
`json:"default_user_rpm_limit"`
DefaultSubscriptions
[]
DefaultSubscriptionSetting
`json:"default_subscriptions"`
// Model fallback configuration
...
...
backend/internal/handler/dto/types.go
View file @
6b0cf466
...
...
@@ -26,6 +26,9 @@ type User struct {
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails"`
TotalRecharged
float64
`json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
RPMLimit
int
`json:"rpm_limit"`
APIKeys
[]
APIKey
`json:"api_keys,omitempty"`
Subscriptions
[]
UserSubscription
`json:"subscriptions,omitempty"`
}
...
...
@@ -108,6 +111,9 @@ type Group struct {
RequireOAuthOnly
bool
`json:"require_oauth_only"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
RPMLimit
int
`json:"rpm_limit"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
...
...
backend/internal/handler/gateway_handler.go
View file @
6b0cf466
...
...
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackAPIKey
:=
cloneAPIKeyWithGroup
(
apiKey
,
fallbackGroup
)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
fallbackAPIKey
.
User
,
fallbackAPIKey
,
fallbackGroup
,
nil
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
errorResponse
(
c
,
status
,
code
,
message
)
return
}
...
...
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c
.
JSON
(
http
.
StatusOK
,
response
)
}
func
billingErrorDetails
(
err
error
)
(
status
int
,
code
,
message
string
)
{
func
billingErrorDetails
(
err
error
)
(
status
int
,
code
,
message
string
,
retryAfter
int
)
{
if
errors
.
Is
(
err
,
service
.
ErrBillingServiceUnavailable
)
{
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
msg
=
"Billing service temporarily unavailable. Please retry later."
}
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
,
0
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit5hExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit1dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit7dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
0
}
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
if
errors
.
Is
(
err
,
service
.
ErrGroupRPMExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrUserRPMExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
retrySeconds
:=
60
-
int
(
time
.
Now
()
.
Unix
()
%
60
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
,
retrySeconds
}
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
...
...
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
)
.
Warn
(
"gateway.billing_error_missing_message"
)
msg
=
"Billing error"
}
return
http
.
StatusForbidden
,
"billing_error"
,
msg
return
http
.
StatusForbidden
,
"billing_error"
,
msg
,
0
}
func
(
h
*
GatewayHandler
)
metadataBridgeEnabled
()
bool
{
...
...
backend/internal/handler/gateway_handler_billing_error_test.go
0 → 100644
View file @
6b0cf466
package
handler
import
(
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrGroupRPMExceeded
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
require
.
Greater
(
t
,
retryAfter
,
0
,
"RPM exceeded should return positive Retry-After"
)
require
.
LessOrEqual
(
t
,
retryAfter
,
60
)
}
func
TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrUserRPMExceeded
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
require
.
Greater
(
t
,
retryAfter
,
0
,
"RPM exceeded should return positive Retry-After"
)
require
.
LessOrEqual
(
t
,
retryAfter
,
60
)
}
func
TestBillingErrorDetails_APIKeyRateLimitStillMaps
(
t
*
testing
.
T
)
{
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
for
_
,
err
:=
range
[]
error
{
service
.
ErrAPIKeyRateLimit5hExceeded
,
service
.
ErrAPIKeyRateLimit1dExceeded
,
service
.
ErrAPIKeyRateLimit7dExceeded
,
}
{
status
,
code
,
_
,
_
:=
billingErrorDetails
(
err
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
status
,
"status for %v"
,
err
)
require
.
Equal
(
t
,
"rate_limit_exceeded"
,
code
)
}
}
func
TestBillingErrorDetails_BillingServiceUnavailableMapsTo503
(
t
*
testing
.
T
)
{
status
,
code
,
_
,
retryAfter
:=
billingErrorDetails
(
service
.
ErrBillingServiceUnavailable
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
status
)
require
.
Equal
(
t
,
"billing_service_error"
,
code
)
require
.
Equal
(
t
,
0
,
retryAfter
,
"non-RPM errors should not set Retry-After"
)
}
func
TestBillingErrorDetails_UnknownErrorFallsBackTo403
(
t
*
testing
.
T
)
{
status
,
code
,
msg
,
_
:=
billingErrorDetails
(
service
.
ErrInsufficientBalance
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
status
)
require
.
Equal
(
t
,
"billing_error"
,
code
)
require
.
NotEmpty
(
t
,
msg
)
}
backend/internal/handler/gateway_handler_chat_completions.go
View file @
6b0cf466
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
...
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.cc.billing_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
chatCompletionsErrorResponse
(
c
,
status
,
code
,
message
)
return
}
...
...
backend/internal/handler/gateway_handler_responses.go
View file @
6b0cf466
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
...
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gateway.responses.billing_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
responsesErrorResponse
(
c
,
status
,
code
,
message
)
return
}
...
...
backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
View file @
6b0cf466
...
...
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cfg
)
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
concurrencySvc
:=
service
.
NewConcurrencyService
(
&
fakeConcurrencyCache
{})
concurrencyHelper
:=
NewConcurrencyHelper
(
concurrencySvc
,
SSEPingFormatClaude
,
0
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
6b0cf466
...
...
@@ -9,6 +9,7 @@ import (
"errors"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
...
...
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"gemini.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
_
,
message
:=
billingErrorDetails
(
err
)
status
,
_
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
googleError
(
c
,
status
,
message
)
return
}
...
...
backend/internal/handler/openai_chat_completions.go
View file @
6b0cf466
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...
...
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai_chat_completions.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
6b0cf466
...
...
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai_messages.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
anthropicStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
backend/internal/handler/openai_images.go
View file @
6b0cf466
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"time"
...
...
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"openai.images.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
status
,
code
,
message
,
retryAfter
:=
billingErrorDetails
(
err
)
if
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
backend/internal/repository/api_key_repo.go
View file @
6b0cf466
...
...
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user
.
FieldSignupSource
,
user
.
FieldLastLoginAt
,
user
.
FieldLastActiveAt
,
user
.
FieldRpmLimit
,
)
})
.
WithGroup
(
func
(
q
*
dbent
.
GroupQuery
)
{
...
...
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldAllowMessagesDispatch
,
group
.
FieldDefaultMappedModel
,
group
.
FieldMessagesDispatchModelConfig
,
group
.
FieldRpmLimit
,
)
})
.
Only
(
ctx
)
...
...
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType
:
u
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
TotalRecharged
:
u
.
TotalRecharged
,
RPMLimit
:
u
.
RpmLimit
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
...
...
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet
:
g
.
RequirePrivacySet
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
g
.
MessagesDispatchModelConfig
,
RPMLimit
:
g
.
RpmLimit
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/repository/group_repo.go
View file @
6b0cf466
...
...
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
.
SetRpmLimit
(
groupIn
.
RPMLimit
)
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
...
...
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetDefaultMappedModel
(
groupIn
.
DefaultMappedModel
)
.
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
SetMessagesDispatchModelConfig
(
groupIn
.
MessagesDispatchModelConfig
)
.
SetRpmLimit
(
groupIn
.
RPMLimit
)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if
groupIn
.
DailyLimitUSD
!=
nil
{
...
...
backend/internal/repository/user_group_rate_repo.go
View file @
6b0cf466
...
...
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql
sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
// NewUserGroupRateRepository 创建用户专属分组倍率
/RPM
仓储
func
NewUserGroupRateRepository
(
sqlDB
*
sql
.
DB
)
service
.
UserGroupRateRepository
{
return
&
userGroupRateRepository
{
sql
:
sqlDB
}
}
// GetByUserID 获取用户
的
所有专属分组
倍率
// GetByUserID 获取用户所有专属分组
rate_multiplier(仅返回非 NULL 的条目)
func
(
r
*
userGroupRateRepository
)
GetByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
map
[
int64
]
float64
,
error
)
{
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
query
:=
`SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1
AND rate_multiplier IS NOT NULL
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
userID
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return
result
,
nil
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构:map[userID]map[groupID]rate
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
func
(
r
*
userGroupRateRepository
)
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
{
result
:=
make
(
map
[
int64
]
map
[
int64
]
float64
,
len
(
userIDs
))
if
len
(
userIDs
)
==
0
{
...
...
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
WHERE user_id = ANY($1)
AND rate_multiplier IS NOT NULL
`
,
pq
.
Array
(
uniqueIDs
))
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return
result
,
nil
}
// GetByGroupID 获取指定分组下所有用户的专属
倍率
// GetByGroupID 获取指定分组下所有用户的专属
配置(rate 与 rpm_override 任一非 NULL 即返回)
func
(
r
*
userGroupRateRepository
)
GetByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
UserGroupRateEntry
,
error
)
{
query
:=
`
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
, ugr.rpm_override
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
...
...
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var
result
[]
service
.
UserGroupRateEntry
for
rows
.
Next
()
{
var
entry
service
.
UserGroupRateEntry
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
entry
.
RateMultiplier
);
err
!=
nil
{
var
rate
sql
.
NullFloat64
var
rpm
sql
.
NullInt32
if
err
:=
rows
.
Scan
(
&
entry
.
UserID
,
&
entry
.
UserName
,
&
entry
.
UserEmail
,
&
entry
.
UserNotes
,
&
entry
.
UserStatus
,
&
rate
,
&
rpm
);
err
!=
nil
{
return
nil
,
err
}
if
rate
.
Valid
{
v
:=
rate
.
Float64
entry
.
RateMultiplier
=
&
v
}
if
rpm
.
Valid
{
v
:=
int
(
rpm
.
Int32
)
entry
.
RPMOverride
=
&
v
}
result
=
append
(
result
,
entry
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
...
...
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return
result
,
nil
}
// GetByUserAndGroup 获取用户在特定分组的专属
倍率
// GetByUserAndGroup 获取用户在特定分组的专属
rate_multiplier(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
query
:=
`SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rate
f
loat64
var
rate
sql
.
NullF
loat64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rate
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
...
...
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if
err
!=
nil
{
return
nil
,
err
}
return
&
rate
,
nil
if
!
rate
.
Valid
{
return
nil
,
nil
}
v
:=
rate
.
Float64
return
&
v
,
nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
func
(
r
*
userGroupRateRepository
)
GetRPMOverrideByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
int
,
error
)
{
query
:=
`SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var
rpm
sql
.
NullInt32
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
[]
any
{
userID
,
groupID
},
&
rpm
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
}
if
err
!=
nil
{
return
nil
,
err
}
if
!
rpm
.
Valid
{
return
nil
,
nil
}
v
:=
int
(
rpm
.
Int32
)
return
&
v
,
nil
}
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func
(
r
*
userGroupRateRepository
)
SyncUserGroupRates
(
ctx
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
if
len
(
rates
)
==
0
{
// 如果传入空 map,删除该用户的所有专属倍率
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`
,
userID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`
,
userID
)
return
err
}
// 分离需要删除和需要 upsert 的记录
var
toDelete
[]
int64
var
clearGroupIDs
[]
int64
upsertGroupIDs
:=
make
([]
int64
,
0
,
len
(
rates
))
upsertRates
:=
make
([]
float64
,
0
,
len
(
rates
))
for
groupID
,
rate
:=
range
rates
{
if
rate
==
nil
{
toDelete
=
append
(
toDelete
,
groupID
)
clearGroupIDs
=
append
(
clearGroupIDs
,
groupID
)
}
else
{
upsertGroupIDs
=
append
(
upsertGroupIDs
,
groupID
)
upsertRates
=
append
(
upsertRates
,
*
rate
)
}
}
// 删除指定的记录
if
len
(
toDelete
)
>
0
{
if
len
(
clearGroupIDs
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`
,
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
}
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`
,
userID
,
pq
.
Array
(
toDelete
));
err
!=
nil
{
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)
AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
userID
,
pq
.
Array
(
clearGroupIDs
));
err
!=
nil
{
return
err
}
}
// Upsert 记录
now
:=
time
.
Now
()
if
len
(
upsertGroupIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
...
...
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return
nil
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
// 语义:
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行:upsert rate_multiplier。
func
(
r
*
userGroupRateRepository
)
SyncGroupRateMultipliers
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRateMultiplierInput
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
);
err
!=
nil
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
}
if
len
(
entries
)
==
0
{
return
nil
}
userIDs
:=
make
([]
int64
,
len
(
entries
))
rates
:=
make
([]
float64
,
len
(
entries
))
for
i
,
e
:=
range
entries
{
...
...
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return
err
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
// 语义:
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func
(
r
*
userGroupRateRepository
)
SyncGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
,
entries
[]
service
.
GroupRPMOverrideInput
)
error
{
keepUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
var
clearUserIDs
[]
int64
upsertUserIDs
:=
make
([]
int64
,
0
,
len
(
entries
))
upsertValues
:=
make
([]
int32
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
keepUserIDs
=
append
(
keepUserIDs
,
e
.
UserID
)
if
e
.
RPMOverride
==
nil
{
clearUserIDs
=
append
(
clearUserIDs
,
e
.
UserID
)
}
else
{
upsertUserIDs
=
append
(
upsertUserIDs
,
e
.
UserID
)
upsertValues
=
append
(
upsertValues
,
int32
(
*
e
.
RPMOverride
))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if
len
(
keepUserIDs
)
==
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
}
else
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`
,
groupID
,
pq
.
Array
(
keepUserIDs
));
err
!=
nil
{
return
err
}
}
// 显式 clear 的行。
if
len
(
clearUserIDs
)
>
0
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`
,
groupID
,
pq
.
Array
(
clearUserIDs
));
err
!=
nil
{
return
err
}
}
// 清空后若整行 NULL 则删除。
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
);
err
!=
nil
{
return
err
}
if
len
(
upsertUserIDs
)
>
0
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`
,
groupID
,
now
,
pq
.
Array
(
upsertUserIDs
),
pq
.
Array
(
upsertValues
))
if
err
!=
nil
{
return
err
}
}
return
nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func
(
r
*
userGroupRateRepository
)
ClearGroupRPMOverrides
(
ctx
context
.
Context
,
groupID
int64
)
error
{
if
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`
,
groupID
);
err
!=
nil
{
return
err
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`
,
groupID
)
return
err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func
(
r
*
userGroupRateRepository
)
DeleteByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`
,
groupID
)
return
err
}
// DeleteByUserID 删除指定用户的所有专属
倍率
// DeleteByUserID 删除指定用户的所有专属
条目
func
(
r
*
userGroupRateRepository
)
DeleteByUserID
(
ctx
context
.
Context
,
userID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1`
,
userID
)
return
err
...
...
backend/internal/repository/user_repo.go
View file @
6b0cf466
...
...
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
.
Save
(
txCtx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
...
...
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType
(
userIn
.
BalanceNotifyThresholdType
)
.
SetNillableBalanceNotifyThreshold
(
userIn
.
BalanceNotifyThreshold
)
.
SetBalanceNotifyExtraEmails
(
marshalExtraEmails
(
userIn
.
BalanceNotifyExtraEmails
))
.
SetTotalRecharged
(
userIn
.
TotalRecharged
)
SetTotalRecharged
(
userIn
.
TotalRecharged
)
.
SetRpmLimit
(
userIn
.
RPMLimit
)
if
userIn
.
SignupSource
!=
""
{
updateOp
=
updateOp
.
SetSignupSource
(
userIn
.
SignupSource
)
}
...
...
backend/internal/repository/user_rpm_cache.go
0 → 100644
View file @
6b0cf466
package
repository
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
const
(
userGroupRPMKeyPrefix
=
"rpm:ug:"
userRPMKeyPrefix
=
"rpm:u:"
userRPMKeyTTL
=
120
*
time
.
Second
)
type
userRPMCacheImpl
struct
{
rdb
*
redis
.
Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func
NewUserRPMCache
(
rdb
*
redis
.
Client
)
service
.
UserRPMCache
{
return
&
userRPMCacheImpl
{
rdb
:
rdb
}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func
(
c
*
userRPMCacheImpl
)
minuteTS
(
ctx
context
.
Context
)
(
int64
,
error
)
{
t
,
err
:=
c
.
rdb
.
Time
(
ctx
)
.
Result
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"redis TIME: %w"
,
err
)
}
return
t
.
Unix
()
/
60
,
nil
}
// atomicIncr 原子 INCR+EXPIRE。
func
(
c
*
userRPMCacheImpl
)
atomicIncr
(
ctx
context
.
Context
,
key
string
)
(
int
,
error
)
{
pipe
:=
c
.
rdb
.
TxPipeline
()
incr
:=
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
userRPMKeyTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm increment: %w"
,
err
)
}
return
int
(
incr
.
Val
()),
nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// IncrementUserRPM 递增用户分钟计数。
func
(
c
*
userRPMCacheImpl
)
IncrementUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
return
c
.
atomicIncr
(
ctx
,
key
)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserGroupRPM
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d:%d"
,
userGroupRPMKeyPrefix
,
userID
,
groupID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user group rpm get: %w"
,
err
)
}
return
val
,
nil
}
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
func
(
c
*
userRPMCacheImpl
)
GetUserRPM
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
minute
,
err
:=
c
.
minuteTS
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
userRPMKeyPrefix
,
userID
,
minute
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"user rpm get: %w"
,
err
)
}
return
val
,
nil
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment