Commit 2b192f7d authored by shaw's avatar shaw
Browse files

feat: 支持用户专属分组倍率配置

parent 6d0152c8
...@@ -59,8 +59,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -59,8 +59,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
...@@ -100,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -100,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
...@@ -153,7 +154,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -153,7 +154,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
......
...@@ -170,6 +170,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk ...@@ -170,6 +170,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
...@@ -203,6 +205,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= ...@@ -203,6 +205,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
...@@ -230,6 +234,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr ...@@ -230,6 +234,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
...@@ -252,6 +258,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= ...@@ -252,6 +258,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......
...@@ -45,6 +45,9 @@ type UpdateUserRequest struct { ...@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"` AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
} }
// UpdateBalanceRequest represents balance update request // UpdateBalanceRequest represents balance update request
...@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) { ...@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Status: req.Status, Status: req.Status,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { ...@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
} }
response.Success(c, out) response.Success(c, out)
} }
// GetUserGroupRates 获取当前用户的专属分组倍率配置
// GET /api/v1/groups/rates
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rates)
}
...@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { ...@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil return nil
} }
return &AdminUser{ return &AdminUser{
User: *base, User: *base,
Notes: u.Notes, Notes: u.Notes,
GroupRates: u.GroupRates,
} }
} }
......
...@@ -29,6 +29,9 @@ type AdminUser struct { ...@@ -29,6 +29,9 @@ type AdminUser struct {
User User
Notes string `json:"notes"` Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
} }
type APIKey struct { type APIKey struct {
......
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map,删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
...@@ -66,6 +66,7 @@ var ProviderSet = wire.NewSet( ...@@ -66,6 +66,7 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository, NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository, NewUserAttributeValueRepository,
NewUserGroupRateRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
...@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps {
} }
userService := service.NewUserService(userRepo, nil) userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
...@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
...@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService ...@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
nil, // userSubRepo nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache nil, // cache
&config.Config{}, &config.Config{},
) )
...@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { ...@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
&config.Config{RunMode: config.RunModeSimple}, &config.Config{RunMode: config.RunModeSimple},
) )
......
...@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
...@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { ...@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
......
...@@ -49,6 +49,7 @@ func RegisterUserRoutes( ...@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups") groups := authenticated.Group("/groups")
{ {
groups.GET("/available", h.APIKey.GetAvailableGroups) groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
} }
// 使用记录 // 使用记录
......
...@@ -93,6 +93,9 @@ type UpdateUserInput struct { ...@@ -93,6 +93,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
} }
type CreateGroupInput struct { type CreateGroupInput struct {
...@@ -293,6 +296,7 @@ type adminServiceImpl struct { ...@@ -293,6 +296,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
...@@ -307,6 +311,7 @@ func NewAdminService( ...@@ -307,6 +311,7 @@ func NewAdminService(
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
...@@ -319,6 +324,7 @@ func NewAdminService( ...@@ -319,6 +324,7 @@ func NewAdminService(
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
...@@ -333,11 +339,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi ...@@ -333,11 +339,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
} }
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
...@@ -406,6 +436,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -406,6 +436,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
...@@ -941,6 +979,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -941,6 +979,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil { if err != nil {
return err return err
} }
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存 // 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil { if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
......
...@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct { ...@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务 // APIKeyService API Key服务
type APIKeyService struct { type APIKeyService struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache APIKeyCache userGroupRateRepo UserGroupRateRepository
cfg *config.Config cache APIKeyCache
authCacheL1 *ristretto.Cache cfg *config.Config
authCfg apiKeyAuthCacheConfig authCacheL1 *ristretto.Cache
authGroup singleflight.Group authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
} }
// NewAPIKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
...@@ -132,16 +133,18 @@ func NewAPIKeyService( ...@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *APIKeyService { ) *APIKeyService {
svc := &APIKeyService{ svc := &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, userGroupRateRepo: userGroupRateRepo,
cfg: cfg, cache: cache,
cfg: cfg,
} }
svc.initAuthCache(cfg) svc.initAuthCache(cfg)
return svc return svc
...@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword ...@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil return keys, nil
} }
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) // CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid // Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
......
...@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
...@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { ...@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil return &APIKeyAuthCacheEntry{NotFound: true}, nil
} }
...@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { ...@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
...@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { ...@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60, L1TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1) require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1") _, err := svc.GetByKey(context.Background(), "k-l1")
...@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { ...@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7) svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
...@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { ...@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9) svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
...@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { ...@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1") svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1) require.Len(t, cache.deleteAuthKeys, 1)
...@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { ...@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
...@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { ...@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true, Singleflight: true,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{}) start := make(chan struct{})
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
......
...@@ -384,6 +384,7 @@ type GatewayService struct { ...@@ -384,6 +384,7 @@ type GatewayService struct {
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
userRepo UserRepository userRepo UserRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache cache GatewayCache
cfg *config.Config cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService schedulerSnapshot *SchedulerSnapshotService
...@@ -405,6 +406,7 @@ func NewGatewayService( ...@@ -405,6 +406,7 @@ func NewGatewayService(
usageLogRepo UsageLogRepository, usageLogRepo UsageLogRepository,
userRepo UserRepository, userRepo UserRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService, schedulerSnapshot *SchedulerSnapshotService,
...@@ -424,6 +426,7 @@ func NewGatewayService( ...@@ -424,6 +426,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
userRepo: userRepo, userRepo: userRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
schedulerSnapshot: schedulerSnapshot, schedulerSnapshot: schedulerSnapshot,
...@@ -4609,10 +4612,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4609,10 +4612,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 获取费率倍数 // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
} }
var cost *CostBreakdown var cost *CostBreakdown
...@@ -4773,10 +4783,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4773,10 +4783,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 获取费率倍数 // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
} }
var cost *CostBreakdown var cost *CostBreakdown
......
...@@ -21,6 +21,10 @@ type User struct { ...@@ -21,6 +21,10 @@ type User struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// TOTP 双因素认证字段 // TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP TotpEnabled bool // 是否启用 TOTP
...@@ -40,18 +44,20 @@ func (u *User) IsActive() bool { ...@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
// CanBindGroup checks whether a user can bind to a given group. // CanBindGroup checks whether a user can bind to a given group.
// For standard groups: // For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list. // - Public groups (non-exclusive): all users can bind
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. // - Exclusive groups: only users with the group in AllowedGroups can bind
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 { // 公开分组(非专属):所有用户都可以绑定
for _, id := range u.AllowedGroups { if !isExclusive {
if id == groupID { return true
return true }
} // 专属分组:需要在 AllowedGroups 中
for _, id := range u.AllowedGroups {
if id == groupID {
return true
} }
return false
} }
return !isExclusive return false
} }
func (u *User) SetPassword(password string) error { func (u *User) SetPassword(password string) error {
......
package service
import "context"
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
type UserGroupRateRepository interface {
// GetByUserID 获取用户的所有专属分组倍率
// 返回 map[groupID]rateMultiplier
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// 如果未设置专属倍率,返回 nil
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
// SyncUserGroupRates 同步用户的分组专属倍率
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}
-- 用户专属分组倍率表
-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
CREATE TABLE IF NOT EXISTS user_group_rate_multipliers (
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
rate_multiplier DECIMAL(10,4) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (user_id, group_id)
);
-- 按 group_id 查询索引(删除分组时清理关联记录)
CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id
ON user_group_rate_multipliers(group_id);
COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置';
COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID';
COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID';
COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)';
...@@ -18,8 +18,18 @@ export async function getAvailable(): Promise<Group[]> { ...@@ -18,8 +18,18 @@ export async function getAvailable(): Promise<Group[]> {
return data return data
} }
/**
* Get current user's custom group rate multipliers
* @returns Map of group_id to custom rate_multiplier
*/
export async function getUserGroupRates(): Promise<Record<number, number>> {
const { data } = await apiClient.get<Record<number, number> | null>('/groups/rates')
return data || {}
}
export const userGroupsAPI = { export const userGroupsAPI = {
getAvailable getAvailable,
getUserGroupRates
} }
export default userGroupsAPI export default userGroupsAPI
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment