"backend/vscode:/vscode.git/clone" did not exist on "1db32d692b9a39b79725956f1ae88780f7432dbb"
Commit dc5d42ad authored by james-6-23's avatar james-6-23
Browse files

feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
parent ef967d8f
...@@ -110,6 +110,8 @@ type CreateGroupRequest struct { ...@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet bool `json:"require_privacy_set"` RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制)
RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定) // 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -145,6 +147,8 @@ type UpdateGroupRequest struct { ...@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet *bool `json:"require_privacy_set"` RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"` DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
...@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet, RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel, DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
...@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) { ...@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"}) response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
} }
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
type BatchSetGroupRPMOverridesRequest struct {
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
}
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
// PUT /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRPMOverridesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
}
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
// DELETE /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders // UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct { type UpdateSortOrderRequest struct {
Updates []struct { Updates []struct {
......
...@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelAnthropic: settings.FallbackModelAnthropic,
...@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct { ...@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
...@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON, CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback, EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelAnthropic: req.FallbackModelAnthropic,
...@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions, DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
......
...@@ -40,6 +40,7 @@ type CreateUserRequest struct { ...@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes string `json:"notes"` Notes string `json:"notes"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
RPMLimit int `json:"rpm_limit"`
AllowedGroups []int64 `json:"allowed_groups"` AllowedGroups []int64 `json:"allowed_groups"`
} }
...@@ -52,6 +53,7 @@ type UpdateUserRequest struct { ...@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes *string `json:"notes"` Notes *string `json:"notes"`
Balance *float64 `json:"balance"` Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
RPMLimit *int `json:"rpm_limit"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"` AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
...@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) { ...@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
}) })
if err != nil { if err != nil {
...@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) { ...@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
Status: req.Status, Status: req.Status,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates, GroupRates: req.GroupRates,
...@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) { ...@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys": result.MigratedKeys, "migrated_keys": result.MigratedKeys,
}) })
} }
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
// GET /api/v1/admin/users/:id/rpm-status
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, status)
}
...@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User { ...@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyThreshold: u.BalanceNotifyThreshold, BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails), BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged, TotalRecharged: u.TotalRecharged,
RPMLimit: u.RPMLimit,
} }
} }
...@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group { ...@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
AllowMessagesDispatch: g.AllowMessagesDispatch, AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOAuthOnly, RequireOAuthOnly: g.RequireOAuthOnly,
RequirePrivacySet: g.RequirePrivacySet, RequirePrivacySet: g.RequirePrivacySet,
RPMLimit: g.RPMLimit,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -108,6 +108,7 @@ type SystemSettings struct { ...@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration // Model fallback configuration
......
...@@ -26,6 +26,9 @@ type User struct { ...@@ -26,6 +26,9 @@ type User struct {
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"` TotalRecharged float64 `json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
RPMLimit int `json:"rpm_limit"`
APIKeys []APIKey `json:"api_keys,omitempty"` APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
} }
...@@ -108,6 +111,9 @@ type Group struct { ...@@ -108,6 +111,9 @@ type Group struct {
RequireOAuthOnly bool `json:"require_oauth_only"` RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"` RequirePrivacySet bool `json:"require_privacy_set"`
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
RPMLimit int `json:"rpm_limit"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
......
...@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅 // 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额) // 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额 // 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.errorResponse(c, status, code, message) h.errorResponse(c, status, code, message)
return return
} }
...@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter ...@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
} }
func billingErrorDetails(err error) (status int, code, message string) { func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) { if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later." msg = "Billing service temporarily unavailable. Please retry later."
} }
return http.StatusServiceUnavailable, "billing_service_error", msg return http.StatusServiceUnavailable, "billing_service_error", msg, 0
} }
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
} }
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
} }
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
msg := pkgerrors.Message(err)
retrySeconds := 60 - int(time.Now().Unix()%60)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
} }
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
...@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) { ...@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
).Warn("gateway.billing_error_missing_message") ).Warn("gateway.billing_error_missing_message")
msg = "Billing error" msg = "Billing error"
} }
return http.StatusForbidden, "billing_error", msg return http.StatusForbidden, "billing_error", msg, 0
} }
func (h *GatewayHandler) metadataBridgeEnabled() bool { func (h *GatewayHandler) metadataBridgeEnabled() bool {
......
package handler
import (
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
require.Equal(t, http.StatusTooManyRequests, status)
require.Equal(t, "rate_limit_exceeded", code)
require.NotEmpty(t, msg)
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
require.LessOrEqual(t, retryAfter, 60)
}
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
require.Equal(t, http.StatusTooManyRequests, status)
require.Equal(t, "rate_limit_exceeded", code)
require.NotEmpty(t, msg)
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
require.LessOrEqual(t, retryAfter, 60)
}
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
for _, err := range []error{
service.ErrAPIKeyRateLimit5hExceeded,
service.ErrAPIKeyRateLimit1dExceeded,
service.ErrAPIKeyRateLimit7dExceeded,
} {
status, code, _, _ := billingErrorDetails(err)
require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
require.Equal(t, "rate_limit_exceeded", code)
}
}
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
require.Equal(t, http.StatusServiceUnavailable, status)
require.Equal(t, "billing_service_error", code)
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
}
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
require.Equal(t, http.StatusForbidden, status)
require.Equal(t, "billing_error", code)
require.NotEmpty(t, msg)
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"strconv"
"time" "time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing // 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.chatCompletionsErrorResponse(c, status, code, message) h.chatCompletionsErrorResponse(c, status, code, message)
return return
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"strconv"
"time" "time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { ...@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing // 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.responsesErrorResponse(c, status, code, message) h.responsesErrorResponse(c, status, code, message)
return return
} }
......
...@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi ...@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/domain"
...@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait) // 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err) status, _, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
googleError(c, status, message) googleError(c, status, message)
return return
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"strconv"
"time" "time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
...@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
......
...@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.anthropicStreamingAwareError(c, status, code, message, streamStarted) h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
...@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { ...@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
......
...@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldSignupSource, user.FieldSignupSource,
user.FieldLastLoginAt, user.FieldLastLoginAt,
user.FieldLastActiveAt, user.FieldLastActiveAt,
user.FieldRpmLimit,
) )
}). }).
WithGroup(func(q *dbent.GroupQuery) { WithGroup(func(q *dbent.GroupQuery) {
...@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldAllowMessagesDispatch, group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel, group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig, group.FieldMessagesDispatchModelConfig,
group.FieldRpmLimit,
) )
}). }).
Only(ctx) Only(ctx)
...@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User { ...@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold, BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged, TotalRecharged: u.TotalRecharged,
RPMLimit: u.RpmLimit,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt, UpdatedAt: u.UpdatedAt,
} }
...@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet: g.RequirePrivacySet, RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel). SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置 // 设置模型路由配置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
...@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel). SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil { if groupIn.DailyLimitUSD != nil {
......
...@@ -13,14 +13,14 @@ type userGroupRateRepository struct { ...@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor sql sqlExecutor
} }
// NewUserGroupRateRepository 创建用户专属分组倍率仓储 // NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB} return &userGroupRateRepository{sql: sqlDB}
} }
// GetByUserID 获取用户所有专属分组倍率 // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID) rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) ...@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil return result, nil
} }
// GetByUserIDs 批量获取多个用户的专属分组倍率。 // GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
// 返回结构:map[userID]map[groupID]rate
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs)) result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 { if len(userIDs) == 0 {
...@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in ...@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, ` rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers FROM user_group_rate_multipliers
WHERE user_id = ANY($1) WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs)) `, pq.Array(uniqueIDs))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in ...@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil return result, nil
} }
// GetByGroupID 获取指定分组下所有用户的专属倍率 // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := ` query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1 WHERE ugr.group_id = $1
...@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 ...@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry var result []service.UserGroupRateEntry
for rows.Next() { for rows.Next() {
var entry service.UserGroupRateEntry var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { var rate sql.NullFloat64
var rpm sql.NullInt32
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err return nil, err
} }
if rate.Valid {
v := rate.Float64
entry.RateMultiplier = &v
}
if rpm.Valid {
v := int(rpm.Int32)
entry.RPMOverride = &v
}
result = append(result, entry) result = append(result, entry)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
...@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 ...@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil return result, nil
} }
// GetByUserAndGroup 获取用户在特定分组的专属倍率 // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64 var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
...@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, ...@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &rate, nil if !rate.Valid {
return nil, nil
}
v := rate.Float64
return &v, nil
}
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rpm sql.NullInt32
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
if !rpm.Valid {
return nil, nil
}
v := int(rpm.Int32)
return &v, nil
} }
// SyncUserGroupRates 同步用户的分组专属倍率 // SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 { if len(rates) == 0 {
// 如果传入空 map,删除该用户的所有专属倍率 if _, err := r.sql.ExecContext(ctx, `
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`, userID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID)
return err return err
} }
// 分离需要删除和需要 upsert 的记录 var clearGroupIDs []int64
var toDelete []int64
upsertGroupIDs := make([]int64, 0, len(rates)) upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates)) upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates { for groupID, rate := range rates {
if rate == nil { if rate == nil {
toDelete = append(toDelete, groupID) clearGroupIDs = append(clearGroupIDs, groupID)
} else { } else {
upsertGroupIDs = append(upsertGroupIDs, groupID) upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate) upsertRates = append(upsertRates, *rate)
} }
} }
// 删除指定的记录 if len(clearGroupIDs) > 0 {
if len(toDelete) > 0 { if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`, userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID, pq.Array(toDelete)); err != nil { userID, pq.Array(clearGroupIDs)); err != nil {
return err return err
} }
} }
// Upsert 记录
now := time.Now()
if len(upsertGroupIDs) > 0 { if len(upsertGroupIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, ` _, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT SELECT
...@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID ...@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil return nil
} }
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) // SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
// 语义:
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行:upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { keepUserIDs := make([]int64, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err return err
} }
if len(entries) == 0 { if len(entries) == 0 {
return nil return nil
} }
userIDs := make([]int64, len(entries)) userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries)) rates := make([]float64, len(entries))
for i, e := range entries { for i, e := range entries {
...@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, ...@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err return err
} }
// DeleteByGroupID 删除指定分组的所有用户专属倍率 // SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
// 语义:
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
keepUserIDs := make([]int64, 0, len(entries))
var clearUserIDs []int64
upsertUserIDs := make([]int64, 0, len(entries))
upsertValues := make([]int32, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
if e.RPMOverride == nil {
clearUserIDs = append(clearUserIDs, e.UserID)
} else {
upsertUserIDs = append(upsertUserIDs, e.UserID)
upsertValues = append(upsertValues, int32(*e.RPMOverride))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 显式 clear 的行。
if len(clearUserIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`, groupID, pq.Array(clearUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(upsertUserIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
if err != nil {
return err
}
}
return nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID)
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err return err
} }
// DeleteByUserID 删除指定用户的所有专属倍率 // DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err return err
......
...@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt). SetNillableLastActiveAt(userIn.LastActiveAt).
SetRpmLimit(userIn.RPMLimit).
Save(txCtx) Save(txCtx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged) SetTotalRecharged(userIn.TotalRecharged).
SetRpmLimit(userIn.RPMLimit)
if userIn.SignupSource != "" { if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource) updateOp = updateOp.SetSignupSource(userIn.SignupSource)
} }
......
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
const (
userGroupRPMKeyPrefix = "rpm:ug:"
userRPMKeyPrefix = "rpm:u:"
userRPMKeyTTL = 120 * time.Second
)
type userRPMCacheImpl struct {
rdb *redis.Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
return &userRPMCacheImpl{rdb: rdb}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
t, err := c.rdb.Time(ctx).Result()
if err != nil {
return 0, fmt.Errorf("redis TIME: %w", err)
}
return t.Unix() / 60, nil
}
// atomicIncr 原子 INCR+EXPIRE。
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
pipe := c.rdb.TxPipeline()
incr := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, userRPMKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("user rpm increment: %w", err)
}
return int(incr.Val()), nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
return c.atomicIncr(ctx, key)
}
// IncrementUserRPM 递增用户分钟计数。
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
return c.atomicIncr(ctx, key)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user group rpm get: %w", err)
}
return val, nil
}
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user rpm get: %w", err)
}
return val, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment