Unverified Commit 445bfdf2 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #706 from PMExtra/feat/default-subscriptions-on-user-create

feat(settings): add default subscriptions for new users
parents fc5b9c82 0fba1901
...@@ -33,7 +33,7 @@ func main() { ...@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
......
...@@ -48,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -48,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redisClient := repository.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
refreshTokenCache := repository.NewRefreshTokenCache(redisClient) refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) groupRepository := repository.NewGroupRepository(client, db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
...@@ -59,15 +60,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -59,15 +60,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
userGroupRateRepository := repository.NewUserGroupRateRepository(db) userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
...@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
......
...@@ -180,6 +180,8 @@ require ( ...@@ -180,6 +180,8 @@ require (
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
......
...@@ -51,6 +51,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -51,6 +51,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled) // Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
for _, sub := range settings.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
...@@ -87,6 +94,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -87,6 +94,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI, FallbackModelOpenAI: settings.FallbackModelOpenAI,
...@@ -146,8 +154,9 @@ type UpdateSettingsRequest struct { ...@@ -146,8 +154,9 @@ type UpdateSettingsRequest struct {
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration // Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"` EnableModelFallback bool `json:"enable_model_fallback"`
...@@ -194,6 +203,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -194,6 +203,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 { if req.SMTPPort <= 0 {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证 // Turnstile 参数验证
if req.TurnstileEnabled { if req.TurnstileEnabled {
...@@ -300,6 +310,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -300,6 +310,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
req.OpsMetricsIntervalSeconds = &v req.OpsMetricsIntervalSeconds = &v
} }
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
for _, sub := range req.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
// 验证最低版本号格式(空字符串=禁用,或合法 semver) // 验证最低版本号格式(空字符串=禁用,或合法 semver)
if req.MinClaudeCodeVersion != "" { if req.MinClaudeCodeVersion != "" {
...@@ -343,6 +360,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -343,6 +360,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SoraClientEnabled: req.SoraClientEnabled, SoraClientEnabled: req.SoraClientEnabled,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback, EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI, FallbackModelOpenAI: req.FallbackModelOpenAI,
...@@ -390,6 +408,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -390,6 +408,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
...@@ -426,6 +451,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -426,6 +451,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SoraClientEnabled: updatedSettings.SoraClientEnabled, SoraClientEnabled: updatedSettings.SoraClientEnabled,
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
...@@ -547,6 +573,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -547,6 +573,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance { if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance") changed = append(changed, "default_balance")
} }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
if before.EnableModelFallback != after.EnableModelFallback { if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback") changed = append(changed, "enable_model_fallback")
} }
...@@ -586,6 +615,35 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -586,6 +615,35 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
return changed return changed
} }
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
if len(input) == 0 {
return nil
}
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
for _, item := range input {
if item.GroupID <= 0 || item.ValidityDays <= 0 {
continue
}
if item.ValidityDays > service.MaxValidityDays {
item.ValidityDays = service.MaxValidityDays
}
normalized = append(normalized, item)
}
return normalized
}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host" binding:"required"`
......
...@@ -39,8 +39,9 @@ type SystemSettings struct { ...@@ -39,8 +39,9 @@ type SystemSettings struct {
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration // Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"` EnableModelFallback bool `json:"enable_model_fallback"`
...@@ -62,6 +63,11 @@ type SystemSettings struct { ...@@ -62,6 +63,11 @@ type SystemSettings struct {
MinClaudeCodeVersion string `json:"min_claude_code_version"` MinClaudeCodeVersion string `json:"min_claude_code_version"`
} }
type DefaultSubscriptionSetting struct {
GroupID int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
}
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
......
...@@ -499,6 +499,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -499,6 +499,7 @@ func TestAPIContracts(t *testing.T) {
"doc_url": "https://docs.example.com", "doc_url": "https://docs.example.com",
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25, "default_balance": 1.25,
"default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_antigravity": "gemini-2.5-pro",
...@@ -620,7 +621,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -620,7 +621,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
...@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { ...@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil) authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{ admin := &service.User{
ID: 1, ID: 1,
......
...@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer ...@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60 cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users} userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc) mw := NewJWTAuthMiddleware(authSvc, userSvc)
......
...@@ -420,6 +420,8 @@ type adminServiceImpl struct { ...@@ -420,6 +420,8 @@ type adminServiceImpl struct {
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
entClient *dbent.Client // 用于开启数据库事务 entClient *dbent.Client // 用于开启数据库事务
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
} }
type userGroupRateBatchReader interface { type userGroupRateBatchReader interface {
...@@ -445,6 +447,8 @@ func NewAdminService( ...@@ -445,6 +447,8 @@ func NewAdminService(
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator, authCacheInvalidator APIKeyAuthCacheInvalidator,
entClient *dbent.Client, entClient *dbent.Client,
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
) AdminService { ) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: userRepo, userRepo: userRepo,
...@@ -460,6 +464,8 @@ func NewAdminService( ...@@ -460,6 +464,8 @@ func NewAdminService(
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
entClient: entClient, entClient: entClient,
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
} }
} }
...@@ -544,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu ...@@ -544,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err return nil, err
} }
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil return user, nil
} }
func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by default user subscriptions setting",
}); err != nil {
logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) { ...@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
require.ErrorIs(t, err, createErr) require.ErrorIs(t, err, createErr)
require.Empty(t, repo.created) require.Empty(t, repo.created)
} }
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 21}
assigner := &defaultSubscriptionAssignerStub{}
cfg := &config.Config{
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
}
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
}}, cfg)
svc := &adminServiceImpl{
userRepo: repo,
settingService: settingService,
defaultSubAssigner: assigner,
}
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "new-user@test.com",
Password: "password",
})
require.NoError(t, err)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(21), assigner.calls[0].UserID)
require.Equal(t, int64(5), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
}
...@@ -56,15 +56,20 @@ type JWTClaims struct { ...@@ -56,15 +56,20 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache refreshTokenCache RefreshTokenCache
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
turnstileService *TurnstileService turnstileService *TurnstileService
emailQueueService *EmailQueueService emailQueueService *EmailQueueService
promoService *PromoService promoService *PromoService
defaultSubAssigner DefaultSubscriptionAssigner
}
type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
} }
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
...@@ -78,17 +83,19 @@ func NewAuthService( ...@@ -78,17 +83,19 @@ func NewAuthService(
turnstileService *TurnstileService, turnstileService *TurnstileService,
emailQueueService *EmailQueueService, emailQueueService *EmailQueueService,
promoService *PromoService, promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache, refreshTokenCache: refreshTokenCache,
cfg: cfg, cfg: cfg,
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
turnstileService: turnstileService, turnstileService: turnstileService,
emailQueueService: emailQueueService, emailQueueService: emailQueueService,
promoService: promoService, promoService: promoService,
defaultSubAssigner: defaultSubAssigner,
} }
} }
...@@ -188,6 +195,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -188,6 +195,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码) // 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
...@@ -477,6 +485,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -477,6 +485,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
} }
} else { } else {
user = newUser user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
} }
} else { } else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
...@@ -572,6 +581,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -572,6 +581,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
} }
} else { } else {
user = newUser user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
} }
} else { } else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
...@@ -597,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -597,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil return tokenPair, user, nil
} }
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by default user subscriptions setting",
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
......
...@@ -56,6 +56,21 @@ type emailCacheStub struct { ...@@ -56,6 +56,21 @@ type emailCacheStub struct {
err error err error
} }
type defaultSubscriptionAssignerStub struct {
calls []AssignSubscriptionInput
err error
}
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
}
if s.err != nil {
return nil, false, s.err
}
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil { if s.err != nil {
return nil, s.err return nil, s.err
...@@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E ...@@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil, nil,
nil, nil,
nil, // promoService nil, // promoService
nil, // defaultSubAssigner
) )
} }
...@@ -381,3 +397,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { ...@@ -381,3 +397,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
} }
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Len(t, assigner.calls, 2)
require.Equal(t, int64(42), assigner.calls[0].UserID)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
...@@ -52,6 +52,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier ...@@ -52,6 +52,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
turnstileService, turnstileService,
nil, // emailQueueService nil, // emailQueueService
nil, // promoService nil, // promoService
nil, // defaultSubAssigner
) )
} }
......
...@@ -117,8 +117,9 @@ const ( ...@@ -117,8 +117,9 @@ const (
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
// 管理员 API Key // 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
......
...@@ -19,10 +19,18 @@ import ( ...@@ -19,10 +19,18 @@ import (
) )
var ( var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
"default subscription group must exist and be subscription type",
)
ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
"default subscription group cannot be duplicated",
)
) )
type SettingRepository interface { type SettingRepository interface {
...@@ -56,13 +64,19 @@ const minVersionErrorTTL = 5 * time.Second ...@@ -56,13 +64,19 @@ const minVersionErrorTTL = 5 * time.Second
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
const minVersionDBTimeout = 5 * time.Second const minVersionDBTimeout = 5 * time.Second
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
}
// SettingService 系统设置服务 // SettingService 系统设置服务
type SettingService struct { type SettingService struct {
settingRepo SettingRepository settingRepo SettingRepository
cfg *config.Config defaultSubGroupReader DefaultSubscriptionGroupReader
onUpdate func() // Callback when settings are updated (for cache invalidation) cfg *config.Config
onS3Update func() // Callback when Sora S3 settings are updated onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
} }
// NewSettingService 创建系统设置服务实例 // NewSettingService 创建系统设置服务实例
...@@ -73,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti ...@@ -73,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
} }
} }
// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation.
func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) {
s.defaultSubGroupReader = reader
}
// GetAllSettings 获取所有系统设置 // GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx) settings, err := s.settingRepo.GetAll(ctx)
...@@ -222,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -222,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
return err
}
updates := make(map[string]string) updates := make(map[string]string)
// 注册设置 // 注册设置
...@@ -274,6 +297,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -274,6 +297,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
return fmt.Errorf("marshal default subscriptions: %w", err)
}
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
// Model fallback configuration // Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
...@@ -297,7 +325,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -297,7 +325,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Claude Code version check // Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
err := s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
minVersionSF.Forget("min_version") minVersionSF.Forget("min_version")
...@@ -312,6 +340,45 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -312,6 +340,45 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
return err return err
} }
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
if len(items) == 0 {
return nil
}
checked := make(map[int64]struct{}, len(items))
for _, item := range items {
if item.GroupID <= 0 {
continue
}
if _, ok := checked[item.GroupID]; ok {
return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{
"group_id": strconv.FormatInt(item.GroupID, 10),
})
}
checked[item.GroupID] = struct{}{}
if s.defaultSubGroupReader == nil {
continue
}
group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID)
if err != nil {
if errors.Is(err, ErrGroupNotFound) {
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
"group_id": strconv.FormatInt(item.GroupID, 10),
})
}
return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err)
}
if !group.IsSubscriptionType() {
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
"group_id": strconv.FormatInt(item.GroupID, 10),
})
}
}
return nil
}
// IsRegistrationEnabled 检查是否开放注册 // IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
...@@ -411,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { ...@@ -411,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance return s.cfg.Default.UserBalance
} }
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
if err != nil {
return nil
}
return parseDefaultSubscriptions(value)
}
// InitializeDefaultSettings 初始化默认设置 // InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置 // 检查是否已有设置
...@@ -435,6 +511,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -435,6 +511,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySoraClientEnabled: "false", SettingKeySoraClientEnabled: "false",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587", SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false", SettingKeySMTPUseTLS: "false",
// Model fallback defaults // Model fallback defaults
...@@ -511,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -511,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else { } else {
result.DefaultBalance = s.cfg.Default.UserBalance result.DefaultBalance = s.cfg.Default.UserBalance
} }
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用 // 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword] result.SMTPPassword = settings[SettingKeySMTPPassword]
...@@ -595,6 +673,31 @@ func isFalseSettingValue(value string) bool { ...@@ -595,6 +673,31 @@ func isFalseSettingValue(value string) bool {
} }
} }
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
var items []DefaultSubscriptionSetting
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return nil
}
normalized := make([]DefaultSubscriptionSetting, 0, len(items))
for _, item := range items {
if item.GroupID <= 0 || item.ValidityDays <= 0 {
continue
}
if item.ValidityDays > MaxValidityDays {
item.ValidityDays = MaxValidityDays
}
normalized = append(normalized, item)
}
return normalized
}
// getStringOrDefault 获取字符串值或默认值 // getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" { if value, ok := settings[key]; ok && value != "" {
......
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
type settingUpdateRepoStub struct {
updates map[string]string
}
func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for k, v := range settings {
s.updates[k] = v
}
return nil
}
func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type defaultSubGroupReaderStub struct {
byID map[int64]*Group
errBy map[int64]error
calls []int64
}
func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) {
s.calls = append(s.calls, id)
if err, ok := s.errBy[id]; ok {
return nil, err
}
if g, ok := s.byID[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) {
repo := &settingUpdateRepoStub{}
groupReader := &defaultSubGroupReaderStub{
byID: map[int64]*Group{
11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := NewSettingService(repo, &config.Config{})
svc.SetDefaultSubscriptionGroupReader(groupReader)
err := svc.UpdateSettings(context.Background(), &SystemSettings{
DefaultSubscriptions: []DefaultSubscriptionSetting{
{GroupID: 11, ValidityDays: 30},
},
})
require.NoError(t, err)
require.Equal(t, []int64{11}, groupReader.calls)
raw, ok := repo.updates[SettingKeyDefaultSubscriptions]
require.True(t, ok)
var got []DefaultSubscriptionSetting
require.NoError(t, json.Unmarshal([]byte(raw), &got))
require.Equal(t, []DefaultSubscriptionSetting{
{GroupID: 11, ValidityDays: 30},
}, got)
}
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) {
repo := &settingUpdateRepoStub{}
groupReader := &defaultSubGroupReaderStub{
byID: map[int64]*Group{
12: {ID: 12, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := NewSettingService(repo, &config.Config{})
svc.SetDefaultSubscriptionGroupReader(groupReader)
err := svc.UpdateSettings(context.Background(), &SystemSettings{
DefaultSubscriptions: []DefaultSubscriptionSetting{
{GroupID: 12, ValidityDays: 7},
},
})
require.Error(t, err)
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
require.Nil(t, repo.updates)
}
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) {
repo := &settingUpdateRepoStub{}
groupReader := &defaultSubGroupReaderStub{
errBy: map[int64]error{
13: ErrGroupNotFound,
},
}
svc := NewSettingService(repo, &config.Config{})
svc.SetDefaultSubscriptionGroupReader(groupReader)
err := svc.UpdateSettings(context.Background(), &SystemSettings{
DefaultSubscriptions: []DefaultSubscriptionSetting{
{GroupID: 13, ValidityDays: 7},
},
})
require.Error(t, err)
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"])
require.Nil(t, repo.updates)
}
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) {
repo := &settingUpdateRepoStub{}
groupReader := &defaultSubGroupReaderStub{
byID: map[int64]*Group{
11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := NewSettingService(repo, &config.Config{})
svc.SetDefaultSubscriptionGroupReader(groupReader)
err := svc.UpdateSettings(context.Background(), &SystemSettings{
DefaultSubscriptions: []DefaultSubscriptionSetting{
{GroupID: 11, ValidityDays: 30},
{GroupID: 11, ValidityDays: 60},
},
})
require.Error(t, err)
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
require.Nil(t, repo.updates)
}
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) {
repo := &settingUpdateRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
DefaultSubscriptions: []DefaultSubscriptionSetting{
{GroupID: 11, ValidityDays: 30},
{GroupID: 11, ValidityDays: 60},
},
})
require.Error(t, err)
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
require.Nil(t, repo.updates)
}
func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
require.Equal(t, []DefaultSubscriptionSetting{
{GroupID: 11, ValidityDays: 30},
{GroupID: 11, ValidityDays: 60},
{GroupID: 12, ValidityDays: MaxValidityDays},
}, got)
}
...@@ -41,8 +41,9 @@ type SystemSettings struct { ...@@ -41,8 +41,9 @@ type SystemSettings struct {
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
SoraClientEnabled bool SoraClientEnabled bool
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration // Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"` EnableModelFallback bool `json:"enable_model_fallback"`
...@@ -65,6 +66,11 @@ type SystemSettings struct { ...@@ -65,6 +66,11 @@ type SystemSettings struct {
MinClaudeCodeVersion string MinClaudeCodeVersion string
} }
type DefaultSubscriptionSetting struct {
GroupID int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
}
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
......
...@@ -284,6 +284,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC ...@@ -284,6 +284,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
return apiKeyService return apiKeyService
} }
// ProvideSettingService wires SettingService with group reader for default subscription validation.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg)
svc.SetDefaultSubscriptionGroupReader(groupRepo)
return svc
}
// ProviderSet is the Wire provider set for all services // ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
// Core services // Core services
...@@ -326,7 +333,7 @@ var ProviderSet = wire.NewSet( ...@@ -326,7 +333,7 @@ var ProviderSet = wire.NewSet(
ProvideRateLimitService, ProvideRateLimitService,
NewAccountUsageService, NewAccountUsageService,
NewAccountTestService, NewAccountTestService,
NewSettingService, ProvideSettingService,
NewDataManagementService, NewDataManagementService,
ProvideOpsSystemLogSink, ProvideOpsSystemLogSink,
NewOpsService, NewOpsService,
...@@ -339,6 +346,7 @@ var ProviderSet = wire.NewSet( ...@@ -339,6 +346,7 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService, ProvideEmailQueueService,
NewTurnstileService, NewTurnstileService,
NewSubscriptionService, NewSubscriptionService,
wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)),
ProvideConcurrencyService, ProvideConcurrencyService,
NewUsageRecordWorkerPool, NewUsageRecordWorkerPool,
ProvideSchedulerSnapshotService, ProvideSchedulerSnapshotService,
......
...@@ -5,6 +5,11 @@ ...@@ -5,6 +5,11 @@
import { apiClient } from '../client' import { apiClient } from '../client'
export interface DefaultSubscriptionSetting {
group_id: number
validity_days: number
}
/** /**
* System settings interface * System settings interface
*/ */
...@@ -20,6 +25,7 @@ export interface SystemSettings { ...@@ -20,6 +25,7 @@ export interface SystemSettings {
// Default settings // Default settings
default_balance: number default_balance: number
default_concurrency: number default_concurrency: number
default_subscriptions: DefaultSubscriptionSetting[]
// OEM settings // OEM settings
site_name: string site_name: string
site_logo: string site_logo: string
...@@ -81,6 +87,7 @@ export interface UpdateSettingsRequest { ...@@ -81,6 +87,7 @@ export interface UpdateSettingsRequest {
totp_enabled?: boolean // TOTP 双因素认证 totp_enabled?: boolean // TOTP 双因素认证
default_balance?: number default_balance?: number
default_concurrency?: number default_concurrency?: number
default_subscriptions?: DefaultSubscriptionSetting[]
site_name?: string site_name?: string
site_logo?: string site_logo?: string
site_subtitle?: string site_subtitle?: string
......
...@@ -3555,7 +3555,15 @@ export default { ...@@ -3555,7 +3555,15 @@ export default {
defaultBalance: 'Default Balance', defaultBalance: 'Default Balance',
defaultBalanceHint: 'Initial balance for new users', defaultBalanceHint: 'Initial balance for new users',
defaultConcurrency: 'Default Concurrency', defaultConcurrency: 'Default Concurrency',
defaultConcurrencyHint: 'Maximum concurrent requests for new users' defaultConcurrencyHint: 'Maximum concurrent requests for new users',
defaultSubscriptions: 'Default Subscriptions',
defaultSubscriptionsHint: 'Auto-assign these subscriptions when a new user is created or registered',
addDefaultSubscription: 'Add Default Subscription',
defaultSubscriptionsEmpty: 'No default subscriptions configured.',
defaultSubscriptionsDuplicate:
'Duplicate subscription group: {groupId}. Each group can only appear once.',
subscriptionGroup: 'Subscription Group',
subscriptionValidityDays: 'Validity (days)'
}, },
claudeCode: { claudeCode: {
title: 'Claude Code Settings', title: 'Claude Code Settings',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment