Commit 22f07a7b authored by shaw's avatar shaw
Browse files

Merge PR #36: refactor: 调整项目结构为单向依赖

parents ecb2c535 e5a77853
package service
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64
Code string
Type string
Value float64
Status string
UsedBy *int64
UsedAt *time.Time
Notes string
CreatedAt time.Time
GroupID *int64
ValidityDays int
User *User
Group *Group
}
func (r *RedeemCode) IsUsed() bool {
return r.Status == StatusUsed
}
func (r *RedeemCode) CanUse() bool {
return r.Status == StatusUnused
}
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
......@@ -10,7 +10,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
)
......@@ -39,17 +38,17 @@ type RedeemCache interface {
}
type RedeemCodeRepository interface {
Create(ctx context.Context, code *model.RedeemCode) error
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
Update(ctx context.Context, code *model.RedeemCode) error
Create(ctx context.Context, code *RedeemCode) error
CreateBatch(ctx context.Context, codes []RedeemCode) error
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
Update(ctx context.Context, code *RedeemCode) error
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
}
// GenerateCodesRequest 生成兑换码请求
......@@ -116,7 +115,7 @@ func (s *RedeemService) GenerateRandomCode() (string, error) {
}
// GenerateCodes 批量生成兑换码
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) {
if req.Count <= 0 {
return nil, errors.New("count must be greater than 0")
}
......@@ -131,21 +130,21 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType := req.Type
if codeType == "" {
codeType = model.RedeemTypeBalance
codeType = RedeemTypeBalance
}
codes := make([]model.RedeemCode, 0, req.Count)
codes := make([]RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
if err != nil {
return nil, fmt.Errorf("generate code: %w", err)
}
codes = append(codes, model.RedeemCode{
codes = append(codes, RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Status: model.StatusUnused,
Status: StatusUnused,
})
}
......@@ -210,7 +209,7 @@ func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
}
// Redeem 使用兑换码
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) {
// 检查限流
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
return nil, err
......@@ -239,7 +238,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
}
......@@ -261,7 +260,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case model.RedeemTypeBalance:
case RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
......@@ -275,13 +274,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}()
}
case model.RedeemTypeConcurrency:
case RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case model.RedeemTypeSubscription:
case RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
validityDays = 30
......@@ -320,7 +319,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get redeem code: %w", err)
......@@ -329,7 +328,7 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
}
// GetByCode 根据Code获取兑换码
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
return nil, fmt.Errorf("get redeem code: %w", err)
......@@ -338,7 +337,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
......@@ -383,7 +382,7 @@ func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
}
// GetUserHistory 获取用户的兑换历史
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
if err != nil {
return nil, fmt.Errorf("get user redeem history: %w", err)
......
package service
import "time"
type Setting struct {
ID int64
Key string
Value string
UpdatedAt time.Time
}
......@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
......@@ -19,7 +18,7 @@ var (
)
type SettingRepository interface {
Get(ctx context.Context, key string) (*model.Setting, error)
Get(ctx context.Context, key string) (*Setting, error)
GetValue(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
......@@ -43,7 +42,7 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("get all settings: %w", err)
......@@ -53,18 +52,18 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSetti
}
// GetPublicSettings 获取公开设置(无需登录)
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
keys := []string{
model.SettingKeyRegistrationEnabled,
model.SettingKeyEmailVerifyEnabled,
model.SettingKeyTurnstileEnabled,
model.SettingKeyTurnstileSiteKey,
model.SettingKeySiteName,
model.SettingKeySiteLogo,
model.SettingKeySiteSubtitle,
model.SettingKeyApiBaseUrl,
model.SettingKeyContactInfo,
model.SettingKeyDocUrl,
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyContactInfo,
SettingKeyDocUrl,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -72,64 +71,64 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSe
return nil, fmt.Errorf("get public settings: %w", err)
}
return &model.PublicSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
DocUrl: settings[model.SettingKeyDocUrl],
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates := make(map[string]string)
// 注册设置
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[model.SettingKeySmtpHost] = settings.SmtpHost
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
if settings.TurnstileSecretKey != "" {
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
// OEM设置
updates[model.SettingKeySiteName] = settings.SiteName
updates[model.SettingKeySiteLogo] = settings.SiteLogo
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[model.SettingKeyContactInfo] = settings.ContactInfo
updates[model.SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
// 默认配置
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
return s.settingRepo.SetMultiple(ctx, updates)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err != nil {
// 默认开放注册
return true
......@@ -139,7 +138,7 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
if err != nil {
return false
}
......@@ -148,7 +147,7 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || value == "" {
return "Sub2API"
}
......@@ -157,7 +156,7 @@ func (s *SettingService) GetSiteName(ctx context.Context) string {
// GetDefaultConcurrency 获取默认并发量
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultConcurrency)
if err != nil {
return s.cfg.Default.UserConcurrency
}
......@@ -169,7 +168,7 @@ func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
// GetDefaultBalance 获取默认余额
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultBalance)
if err != nil {
return s.cfg.Default.UserBalance
}
......@@ -182,7 +181,7 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
_, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err == nil {
// 已有设置,不需要初始化
return nil
......@@ -193,62 +192,62 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
model.SettingKeyRegistrationEnabled: "true",
model.SettingKeyEmailVerifyEnabled: "false",
model.SettingKeySiteName: "Sub2API",
model.SettingKeySiteLogo: "",
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
model.SettingKeySmtpPort: "587",
model.SettingKeySmtpUseTLS: "false",
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
}
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
result := &model.SystemSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[model.SettingKeySmtpHost],
SmtpUsername: settings[model.SettingKeySmtpUsername],
SmtpFrom: settings[model.SettingKeySmtpFrom],
SmtpFromName: settings[model.SettingKeySmtpFromName],
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
DocUrl: settings[model.SettingKeyDocUrl],
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
result.DefaultConcurrency = concurrency
} else {
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result
}
......@@ -263,7 +262,7 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled)
if err != nil {
return false
}
......@@ -272,7 +271,7 @@ func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileSecretKey)
if err != nil {
return ""
}
......@@ -287,10 +286,10 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
......@@ -300,7 +299,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
......@@ -324,7 +323,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
......@@ -336,5 +335,5 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
}
package service
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
ContactInfo string
DocUrl string
DefaultConcurrency int
DefaultBalance float64
}
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
ContactInfo string
DocUrl string
Version string
}
......@@ -7,7 +7,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
......@@ -48,7 +47,7 @@ type AssignSubscriptionInput struct {
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
......@@ -91,7 +90,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
......@@ -132,8 +131,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 如果订阅已过期或被暂停,恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
if existingSub.Status != SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
......@@ -185,19 +184,19 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
now := time.Now()
sub := &model.UserSubscription{
sub := &UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
Status: SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
......@@ -229,14 +228,14 @@ type BulkAssignSubscriptionInput struct {
type BulkAssignResult struct {
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Subscriptions []UserSubscription
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Subscriptions: make([]UserSubscription, 0),
Errors: make([]string, 0),
}
......@@ -286,7 +285,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
......@@ -299,8 +298,8 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 如果订阅已过期,恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
if sub.Status == SubscriptionStatusExpired {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil {
return nil, err
}
}
......@@ -319,12 +318,12 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) {
return s.userSubRepo.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
......@@ -333,7 +332,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListByUserID(ctx, userID)
if err != nil {
return nil, err
......@@ -343,7 +342,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
......@@ -353,7 +352,7 @@ func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, u
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
......@@ -364,7 +363,7 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
if err != nil {
......@@ -376,7 +375,7 @@ func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, user
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
func normalizeExpiredWindows(subs []model.UserSubscription) {
func normalizeExpiredWindows(subs []UserSubscription) {
for i := range subs {
sub := &subs[i]
// 日窗口过期:清零展示数据
......@@ -403,7 +402,7 @@ func startOfDay(t time.Time) time.Time {
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error {
if sub.IsWindowActivated() {
return nil
}
......@@ -414,7 +413,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间
windowStart := startOfDay(time.Now())
needsInvalidateCache := false
......@@ -458,7 +457,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
......@@ -620,16 +619,16 @@ func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (i
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
if sub.Status == model.SubscriptionStatusExpired {
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
if sub.Status == SubscriptionStatusExpired {
return ErrSubscriptionExpired
}
if sub.Status == model.SubscriptionStatusSuspended {
if sub.Status == SubscriptionStatusSuspended {
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil
......
......@@ -8,7 +8,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefreshService OAuth token自动刷新服务
......@@ -142,19 +141,19 @@ func (s *TokenRefreshService) processRefresh() {
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
return s.accountRepo.ListActive(ctx)
}
// refreshWithRetry 带重试的刷新
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
var lastErr error
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功,更新账号credentials
account.Credentials = model.JSONB(newCredentials)
account.Credentials = newCredentials
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
......
......@@ -4,22 +4,20 @@ import (
"context"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
type TokenRefresher interface {
// CanRefresh 检查此刷新器是否能处理指定账号
CanRefresh(account *model.Account) bool
CanRefresh(account *Account) bool
// NeedsRefresh 检查账号的token是否需要刷新
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
NeedsRefresh(account *Account, refreshWindow time.Duration) bool
// Refresh 执行token刷新,返回更新后的credentials
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
Refresh(ctx context.Context, account *Account) (map[string]any, error)
}
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
......@@ -37,14 +35,14 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformAnthropic &&
account.Type == model.AccountTypeOAuth
func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAnthropic &&
account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
......@@ -61,7 +59,7 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
......@@ -103,14 +101,14 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAIToke
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号
func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformOpenAI &&
account.Type == model.AccountTypeOAuth
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformOpenAI &&
account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false
......@@ -121,7 +119,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
......
package service
import "time"
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64
UserID int64
ApiKeyID int64
AccountID int64
RequestID string
Model string
GroupID *int64
SubscriptionID *int64
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
InputCost float64
OutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64
RateMultiplier float64
BillingType int8
Stream bool
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time
User *User
ApiKey *ApiKey
Account *Account
Group *Group
Subscription *UserSubscription
}
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}
......@@ -6,7 +6,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
......@@ -66,7 +65,7 @@ func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *Usa
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
......@@ -74,7 +73,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// 创建使用日志
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
......@@ -112,7 +111,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get usage log: %w", err)
......@@ -121,7 +120,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
......@@ -130,7 +129,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
......@@ -139,7 +138,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
......@@ -243,7 +242,7 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
}
// calculateStats 计算统计数据
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats {
stats := &UsageStats{}
for _, log := range logs {
......@@ -313,7 +312,7 @@ func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs [
}
// ListWithFilters lists usage logs with admin filters.
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) {
logs, result, err := s.usageRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs with filters: %w", err)
......
package service
import (
"time"
"golang.org/x/crypto/bcrypt"
)
type User struct {
ID int64
Email string
Username string
Wechat string
Notes string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
CreatedAt time.Time
UpdatedAt time.Time
ApiKeys []ApiKey
Subscriptions []UserSubscription
}
func (u *User) IsAdmin() bool {
return u.Role == RoleAdmin
}
func (u *User) IsActive() bool {
return u.Status == StatusActive
}
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
return !isExclusive
}
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
func (u *User) CheckPassword(password string) bool {
return bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil
}
......@@ -5,9 +5,7 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt"
)
var (
......@@ -17,15 +15,15 @@ var (
)
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
GetByID(ctx context.Context, id int64) (*model.User, error)
GetByEmail(ctx context.Context, email string) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error)
GetByEmail(ctx context.Context, email string) (*User, error)
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
......@@ -61,7 +59,7 @@ func NewUserService(userRepo UserRepository) *UserService {
}
// GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证)
func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
admin, err := s.userRepo.GetFirstAdmin(ctx)
if err != nil {
return nil, fmt.Errorf("get first admin: %w", err)
......@@ -70,7 +68,7 @@ func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
}
// GetProfile 获取用户资料
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
......@@ -79,7 +77,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
......@@ -125,18 +123,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// 验证当前密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
if !user.CheckPassword(req.CurrentPassword) {
return ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
if err := user.SetPassword(req.NewPassword); err != nil {
return fmt.Errorf("set password: %w", err)
}
user.PasswordHash = string(hashedPassword)
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
......@@ -145,7 +139,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// GetByID 根据ID获取用户(管理员功能)
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
......@@ -154,7 +148,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
......
package model
package service
import (
"time"
)
import "time"
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间(nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度(USD,基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
ID int64
UserID int64
GroupID int64
StartsAt time.Time
ExpiresAt time.Time
Status string
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64
WeeklyUsageUSD float64
MonthlyUsageUSD float64
AssignedBy *int64
AssignedAt time.Time
Notes string
CreatedAt time.Time
UpdatedAt time.Time
func (UserSubscription) TableName() string {
return "user_subscriptions"
User *User
Group *Group
AssignedByUser *User
}
// IsActive 检查订阅是否有效(状态为active且未过期)
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
......@@ -68,12 +46,10 @@ func (s *UserSubscription) DaysRemaining() int {
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
......@@ -81,7 +57,6 @@ func (s *UserSubscription) NeedsDailyReset() bool {
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
......@@ -89,7 +64,6 @@ func (s *UserSubscription) NeedsWeeklyReset() bool {
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
......@@ -97,7 +71,6 @@ func (s *UserSubscription) NeedsMonthlyReset() bool {
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
......@@ -106,7 +79,6 @@ func (s *UserSubscription) DailyResetTime() *time.Time {
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
......@@ -115,7 +87,6 @@ func (s *UserSubscription) WeeklyResetTime() *time.Time {
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
......@@ -124,31 +95,27 @@ func (s *UserSubscription) MonthlyResetTime() *time.Time {
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
return true
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
return true
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
return true
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
......
......@@ -4,22 +4,21 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type UserSubscriptionRepository interface {
Create(ctx context.Context, sub *model.UserSubscription) error
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
Update(ctx context.Context, sub *model.UserSubscription) error
Create(ctx context.Context, sub *UserSubscription) error
GetByID(ctx context.Context, id int64) (*UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
Update(ctx context.Context, sub *UserSubscription) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
......@@ -10,10 +10,10 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
"gorm.io/driver/postgres"
"gorm.io/gorm"
......@@ -271,8 +271,7 @@ func initializeDatabase(cfg *SetupConfig) error {
}
}()
// 使用 model 包的 AutoMigrate,确保模型定义统一
return model.AutoMigrate(db)
return repository.AutoMigrate(db)
}
func createAdminUser(cfg *SetupConfig) error {
......@@ -299,29 +298,28 @@ func createAdminUser(cfg *SetupConfig) error {
// Check if admin already exists
var count int64
db.Model(&model.User{}).Where("role = ?", "admin").Count(&count)
if err := db.Table("users").Where("role = ?", service.RoleAdmin).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return nil // Admin already exists
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost)
if err != nil {
return err
}
// Create admin user
admin := &model.User{
admin := &service.User{
Email: cfg.Admin.Email,
PasswordHash: string(hashedPassword),
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 0,
Concurrency: 5,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return db.Create(admin).Error
if err := admin.SetPassword(cfg.Admin.Password); err != nil {
return err
}
return repository.NewUserRepository(db).Create(context.Background(), admin)
}
func writeConfigFile(cfg *SetupConfig) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment