Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
e5a77853
Commit
e5a77853
authored
Dec 26, 2025
by
Forest
Browse files
refactor: 调整项目结构为单向依赖
parent
b3463769
Changes
95
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/redeem_code.go
0 → 100644
View file @
e5a77853
package
service
import
(
"crypto/rand"
"encoding/hex"
"time"
)
type
RedeemCode
struct
{
ID
int64
Code
string
Type
string
Value
float64
Status
string
UsedBy
*
int64
UsedAt
*
time
.
Time
Notes
string
CreatedAt
time
.
Time
GroupID
*
int64
ValidityDays
int
User
*
User
Group
*
Group
}
func
(
r
*
RedeemCode
)
IsUsed
()
bool
{
return
r
.
Status
==
StatusUsed
}
func
(
r
*
RedeemCode
)
CanUse
()
bool
{
return
r
.
Status
==
StatusUnused
}
func
GenerateRedeemCode
()
(
string
,
error
)
{
b
:=
make
([]
byte
,
16
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
backend/internal/service/redeem_service.go
View file @
e5a77853
...
...
@@ -10,7 +10,6 @@ import (
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
)
...
...
@@ -39,17 +38,17 @@ type RedeemCache interface {
}
type
RedeemCodeRepository
interface
{
Create
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
CreateBatch
(
ctx
context
.
Context
,
codes
[]
model
.
RedeemCode
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
Update
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
Create
(
ctx
context
.
Context
,
code
*
RedeemCode
)
error
CreateBatch
(
ctx
context
.
Context
,
codes
[]
RedeemCode
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
RedeemCode
,
error
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
RedeemCode
,
error
)
Update
(
ctx
context
.
Context
,
code
*
RedeemCode
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
Use
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
model
.
RedeemCode
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
RedeemCode
,
error
)
}
// GenerateCodesRequest 生成兑换码请求
...
...
@@ -116,7 +115,7 @@ func (s *RedeemService) GenerateRandomCode() (string, error) {
}
// GenerateCodes 批量生成兑换码
func
(
s
*
RedeemService
)
GenerateCodes
(
ctx
context
.
Context
,
req
GenerateCodesRequest
)
([]
model
.
RedeemCode
,
error
)
{
func
(
s
*
RedeemService
)
GenerateCodes
(
ctx
context
.
Context
,
req
GenerateCodesRequest
)
([]
RedeemCode
,
error
)
{
if
req
.
Count
<=
0
{
return
nil
,
errors
.
New
(
"count must be greater than 0"
)
}
...
...
@@ -131,21 +130,21 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType
:=
req
.
Type
if
codeType
==
""
{
codeType
=
model
.
RedeemTypeBalance
codeType
=
RedeemTypeBalance
}
codes
:=
make
([]
model
.
RedeemCode
,
0
,
req
.
Count
)
codes
:=
make
([]
RedeemCode
,
0
,
req
.
Count
)
for
i
:=
0
;
i
<
req
.
Count
;
i
++
{
code
,
err
:=
s
.
GenerateRandomCode
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate code: %w"
,
err
)
}
codes
=
append
(
codes
,
model
.
RedeemCode
{
codes
=
append
(
codes
,
RedeemCode
{
Code
:
code
,
Type
:
codeType
,
Value
:
req
.
Value
,
Status
:
model
.
StatusUnused
,
Status
:
StatusUnused
,
})
}
...
...
@@ -210,7 +209,7 @@ func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
}
// Redeem 使用兑换码
func
(
s
*
RedeemService
)
Redeem
(
ctx
context
.
Context
,
userID
int64
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
{
func
(
s
*
RedeemService
)
Redeem
(
ctx
context
.
Context
,
userID
int64
,
code
string
)
(
*
RedeemCode
,
error
)
{
// 检查限流
if
err
:=
s
.
checkRedeemRateLimit
(
ctx
,
userID
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -239,7 +238,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// 验证兑换码类型的前置条件
if
redeemCode
.
Type
==
model
.
RedeemTypeSubscription
&&
redeemCode
.
GroupID
==
nil
{
if
redeemCode
.
Type
==
RedeemTypeSubscription
&&
redeemCode
.
GroupID
==
nil
{
return
nil
,
infraerrors
.
BadRequest
(
"REDEEM_CODE_INVALID"
,
"invalid subscription redeem code: missing group_id"
)
}
...
...
@@ -261,7 +260,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch
redeemCode
.
Type
{
case
model
.
RedeemTypeBalance
:
case
RedeemTypeBalance
:
// 增加用户余额
if
err
:=
s
.
userRepo
.
UpdateBalance
(
ctx
,
userID
,
redeemCode
.
Value
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
...
...
@@ -275,13 +274,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}()
}
case
model
.
RedeemTypeConcurrency
:
case
RedeemTypeConcurrency
:
// 增加用户并发数
if
err
:=
s
.
userRepo
.
UpdateConcurrency
(
ctx
,
userID
,
int
(
redeemCode
.
Value
));
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user concurrency: %w"
,
err
)
}
case
model
.
RedeemTypeSubscription
:
case
RedeemTypeSubscription
:
validityDays
:=
redeemCode
.
ValidityDays
if
validityDays
<=
0
{
validityDays
=
30
...
...
@@ -320,7 +319,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// GetByID 根据ID获取兑换码
func
(
s
*
RedeemService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
func
(
s
*
RedeemService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
RedeemCode
,
error
)
{
code
,
err
:=
s
.
redeemRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get redeem code: %w"
,
err
)
...
...
@@ -329,7 +328,7 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
}
// GetByCode 根据Code获取兑换码
func
(
s
*
RedeemService
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
{
func
(
s
*
RedeemService
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
RedeemCode
,
error
)
{
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
code
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get redeem code: %w"
,
err
)
...
...
@@ -338,7 +337,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
}
// List 获取兑换码列表(管理员功能)
func
(
s
*
RedeemService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
RedeemService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
codes
,
pagination
,
err
:=
s
.
redeemRepo
.
List
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list redeem codes: %w"
,
err
)
...
...
@@ -383,7 +382,7 @@ func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
}
// GetUserHistory 获取用户的兑换历史
func
(
s
*
RedeemService
)
GetUserHistory
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
model
.
RedeemCode
,
error
)
{
func
(
s
*
RedeemService
)
GetUserHistory
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
RedeemCode
,
error
)
{
codes
,
err
:=
s
.
redeemRepo
.
ListByUser
(
ctx
,
userID
,
limit
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user redeem history: %w"
,
err
)
...
...
backend/internal/service/setting.go
0 → 100644
View file @
e5a77853
package
service
import
"time"
type
Setting
struct
{
ID
int64
Key
string
Value
string
UpdatedAt
time
.
Time
}
backend/internal/service/setting_service.go
View file @
e5a77853
...
...
@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var
(
...
...
@@ -19,7 +18,7 @@ var (
)
type
SettingRepository
interface
{
Get
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
Setting
,
error
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
...
...
@@ -43,7 +42,7 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
}
// GetAllSettings 获取所有系统设置
func
(
s
*
SettingService
)
GetAllSettings
(
ctx
context
.
Context
)
(
*
model
.
SystemSettings
,
error
)
{
func
(
s
*
SettingService
)
GetAllSettings
(
ctx
context
.
Context
)
(
*
SystemSettings
,
error
)
{
settings
,
err
:=
s
.
settingRepo
.
GetAll
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get all settings: %w"
,
err
)
...
...
@@ -53,18 +52,18 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSetti
}
// GetPublicSettings 获取公开设置(无需登录)
func
(
s
*
SettingService
)
GetPublicSettings
(
ctx
context
.
Context
)
(
*
model
.
PublicSettings
,
error
)
{
func
(
s
*
SettingService
)
GetPublicSettings
(
ctx
context
.
Context
)
(
*
PublicSettings
,
error
)
{
keys
:=
[]
string
{
model
.
SettingKeyRegistrationEnabled
,
model
.
SettingKeyEmailVerifyEnabled
,
model
.
SettingKeyTurnstileEnabled
,
model
.
SettingKeyTurnstileSiteKey
,
model
.
SettingKeySiteName
,
model
.
SettingKeySiteLogo
,
model
.
SettingKeySiteSubtitle
,
model
.
SettingKeyApiBaseUrl
,
model
.
SettingKeyContactInfo
,
model
.
SettingKeyDocUrl
,
SettingKeyRegistrationEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileSiteKey
,
SettingKeySiteName
,
SettingKeySiteLogo
,
SettingKeySiteSubtitle
,
SettingKeyApiBaseUrl
,
SettingKeyContactInfo
,
SettingKeyDocUrl
,
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
...
...
@@ -72,64 +71,64 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSe
return
nil
,
fmt
.
Errorf
(
"get public settings: %w"
,
err
)
}
return
&
model
.
PublicSettings
{
RegistrationEnabled
:
settings
[
model
.
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
model
.
SettingKeyEmailVerifyEnabled
]
==
"true"
,
TurnstileEnabled
:
settings
[
model
.
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
model
.
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
model
.
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
model
.
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
model
.
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
ApiBaseUrl
:
settings
[
model
.
SettingKeyApiBaseUrl
],
ContactInfo
:
settings
[
model
.
SettingKeyContactInfo
],
DocUrl
:
settings
[
model
.
SettingKeyDocUrl
],
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
ApiBaseUrl
:
settings
[
SettingKeyApiBaseUrl
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocUrl
:
settings
[
SettingKeyDocUrl
],
},
nil
}
// UpdateSettings 更新系统设置
func
(
s
*
SettingService
)
UpdateSettings
(
ctx
context
.
Context
,
settings
*
model
.
SystemSettings
)
error
{
func
(
s
*
SettingService
)
UpdateSettings
(
ctx
context
.
Context
,
settings
*
SystemSettings
)
error
{
updates
:=
make
(
map
[
string
]
string
)
// 注册设置
updates
[
model
.
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
model
.
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
// 邮件服务设置(只有非空才更新密码)
updates
[
model
.
SettingKeySmtpHost
]
=
settings
.
SmtpHost
updates
[
model
.
SettingKeySmtpPort
]
=
strconv
.
Itoa
(
settings
.
SmtpPort
)
updates
[
model
.
SettingKeySmtpUsername
]
=
settings
.
SmtpUsername
updates
[
SettingKeySmtpHost
]
=
settings
.
SmtpHost
updates
[
SettingKeySmtpPort
]
=
strconv
.
Itoa
(
settings
.
SmtpPort
)
updates
[
SettingKeySmtpUsername
]
=
settings
.
SmtpUsername
if
settings
.
SmtpPassword
!=
""
{
updates
[
model
.
SettingKeySmtpPassword
]
=
settings
.
SmtpPassword
updates
[
SettingKeySmtpPassword
]
=
settings
.
SmtpPassword
}
updates
[
model
.
SettingKeySmtpFrom
]
=
settings
.
SmtpFrom
updates
[
model
.
SettingKeySmtpFromName
]
=
settings
.
SmtpFromName
updates
[
model
.
SettingKeySmtpUseTLS
]
=
strconv
.
FormatBool
(
settings
.
SmtpUseTLS
)
updates
[
SettingKeySmtpFrom
]
=
settings
.
SmtpFrom
updates
[
SettingKeySmtpFromName
]
=
settings
.
SmtpFromName
updates
[
SettingKeySmtpUseTLS
]
=
strconv
.
FormatBool
(
settings
.
SmtpUseTLS
)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates
[
model
.
SettingKeyTurnstileEnabled
]
=
strconv
.
FormatBool
(
settings
.
TurnstileEnabled
)
updates
[
model
.
SettingKeyTurnstileSiteKey
]
=
settings
.
TurnstileSiteKey
updates
[
SettingKeyTurnstileEnabled
]
=
strconv
.
FormatBool
(
settings
.
TurnstileEnabled
)
updates
[
SettingKeyTurnstileSiteKey
]
=
settings
.
TurnstileSiteKey
if
settings
.
TurnstileSecretKey
!=
""
{
updates
[
model
.
SettingKeyTurnstileSecretKey
]
=
settings
.
TurnstileSecretKey
updates
[
SettingKeyTurnstileSecretKey
]
=
settings
.
TurnstileSecretKey
}
// OEM设置
updates
[
model
.
SettingKeySiteName
]
=
settings
.
SiteName
updates
[
model
.
SettingKeySiteLogo
]
=
settings
.
SiteLogo
updates
[
model
.
SettingKeySiteSubtitle
]
=
settings
.
SiteSubtitle
updates
[
model
.
SettingKeyApiBaseUrl
]
=
settings
.
ApiBaseUrl
updates
[
model
.
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
model
.
SettingKeyDocUrl
]
=
settings
.
DocUrl
updates
[
SettingKeySiteName
]
=
settings
.
SiteName
updates
[
SettingKeySiteLogo
]
=
settings
.
SiteLogo
updates
[
SettingKeySiteSubtitle
]
=
settings
.
SiteSubtitle
updates
[
SettingKeyApiBaseUrl
]
=
settings
.
ApiBaseUrl
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocUrl
]
=
settings
.
DocUrl
// 默认配置
updates
[
model
.
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
model
.
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
}
// IsRegistrationEnabled 检查是否开放注册
func
(
s
*
SettingService
)
IsRegistrationEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyRegistrationEnabled
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyRegistrationEnabled
)
if
err
!=
nil
{
// 默认开放注册
return
true
...
...
@@ -139,7 +138,7 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
// IsEmailVerifyEnabled 检查是否开启邮件验证
func
(
s
*
SettingService
)
IsEmailVerifyEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyEmailVerifyEnabled
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyEmailVerifyEnabled
)
if
err
!=
nil
{
return
false
}
...
...
@@ -148,7 +147,7 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
// GetSiteName 获取网站名称
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeySiteName
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
if
err
!=
nil
||
value
==
""
{
return
"Sub2API"
}
...
...
@@ -157,7 +156,7 @@ func (s *SettingService) GetSiteName(ctx context.Context) string {
// GetDefaultConcurrency 获取默认并发量
func
(
s
*
SettingService
)
GetDefaultConcurrency
(
ctx
context
.
Context
)
int
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyDefaultConcurrency
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyDefaultConcurrency
)
if
err
!=
nil
{
return
s
.
cfg
.
Default
.
UserConcurrency
}
...
...
@@ -169,7 +168,7 @@ func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
// GetDefaultBalance 获取默认余额
func
(
s
*
SettingService
)
GetDefaultBalance
(
ctx
context
.
Context
)
float64
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyDefaultBalance
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyDefaultBalance
)
if
err
!=
nil
{
return
s
.
cfg
.
Default
.
UserBalance
}
...
...
@@ -182,7 +181,7 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
// InitializeDefaultSettings 初始化默认设置
func
(
s
*
SettingService
)
InitializeDefaultSettings
(
ctx
context
.
Context
)
error
{
// 检查是否已有设置
_
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyRegistrationEnabled
)
_
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyRegistrationEnabled
)
if
err
==
nil
{
// 已有设置,不需要初始化
return
nil
...
...
@@ -193,62 +192,62 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults
:=
map
[
string
]
string
{
model
.
SettingKeyRegistrationEnabled
:
"true"
,
model
.
SettingKeyEmailVerifyEnabled
:
"false"
,
model
.
SettingKeySiteName
:
"Sub2API"
,
model
.
SettingKeySiteLogo
:
""
,
model
.
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
model
.
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
model
.
SettingKeySmtpPort
:
"587"
,
model
.
SettingKeySmtpUseTLS
:
"false"
,
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeySiteName
:
"Sub2API"
,
SettingKeySiteLogo
:
""
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySmtpPort
:
"587"
,
SettingKeySmtpUseTLS
:
"false"
,
}
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
}
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
model
.
SystemSettings
{
result
:=
&
model
.
SystemSettings
{
RegistrationEnabled
:
settings
[
model
.
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
model
.
SettingKeyEmailVerifyEnabled
]
==
"true"
,
SmtpHost
:
settings
[
model
.
SettingKeySmtpHost
],
SmtpUsername
:
settings
[
model
.
SettingKeySmtpUsername
],
SmtpFrom
:
settings
[
model
.
SettingKeySmtpFrom
],
SmtpFromName
:
settings
[
model
.
SettingKeySmtpFromName
],
SmtpUseTLS
:
settings
[
model
.
SettingKeySmtpUseTLS
]
==
"true"
,
TurnstileEnabled
:
settings
[
model
.
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
model
.
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
model
.
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
model
.
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
model
.
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
ApiBaseUrl
:
settings
[
model
.
SettingKeyApiBaseUrl
],
ContactInfo
:
settings
[
model
.
SettingKeyContactInfo
],
DocUrl
:
settings
[
model
.
SettingKeyDocUrl
],
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
SmtpHost
:
settings
[
SettingKeySmtpHost
],
SmtpUsername
:
settings
[
SettingKeySmtpUsername
],
SmtpFrom
:
settings
[
SettingKeySmtpFrom
],
SmtpFromName
:
settings
[
SettingKeySmtpFromName
],
SmtpUseTLS
:
settings
[
SettingKeySmtpUseTLS
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
ApiBaseUrl
:
settings
[
SettingKeyApiBaseUrl
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocUrl
:
settings
[
SettingKeyDocUrl
],
}
// 解析整数类型
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
model
.
SettingKeySmtpPort
]);
err
==
nil
{
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeySmtpPort
]);
err
==
nil
{
result
.
SmtpPort
=
port
}
else
{
result
.
SmtpPort
=
587
}
if
concurrency
,
err
:=
strconv
.
Atoi
(
settings
[
model
.
SettingKeyDefaultConcurrency
]);
err
==
nil
{
if
concurrency
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyDefaultConcurrency
]);
err
==
nil
{
result
.
DefaultConcurrency
=
concurrency
}
else
{
result
.
DefaultConcurrency
=
s
.
cfg
.
Default
.
UserConcurrency
}
// 解析浮点数类型
if
balance
,
err
:=
strconv
.
ParseFloat
(
settings
[
model
.
SettingKeyDefaultBalance
],
64
);
err
==
nil
{
if
balance
,
err
:=
strconv
.
ParseFloat
(
settings
[
SettingKeyDefaultBalance
],
64
);
err
==
nil
{
result
.
DefaultBalance
=
balance
}
else
{
result
.
DefaultBalance
=
s
.
cfg
.
Default
.
UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result
.
SmtpPassword
=
settings
[
model
.
SettingKeySmtpPassword
]
result
.
TurnstileSecretKey
=
settings
[
model
.
SettingKeyTurnstileSecretKey
]
result
.
SmtpPassword
=
settings
[
SettingKeySmtpPassword
]
result
.
TurnstileSecretKey
=
settings
[
SettingKeyTurnstileSecretKey
]
return
result
}
...
...
@@ -263,7 +262,7 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func
(
s
*
SettingService
)
IsTurnstileEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyTurnstileEnabled
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTurnstileEnabled
)
if
err
!=
nil
{
return
false
}
...
...
@@ -272,7 +271,7 @@ func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func
(
s
*
SettingService
)
GetTurnstileSecretKey
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyTurnstileSecretKey
)
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTurnstileSecretKey
)
if
err
!=
nil
{
return
""
}
...
...
@@ -287,10 +286,10 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
return
""
,
fmt
.
Errorf
(
"generate random bytes: %w"
,
err
)
}
key
:=
model
.
AdminApiKeyPrefix
+
hex
.
EncodeToString
(
bytes
)
key
:=
AdminApiKeyPrefix
+
hex
.
EncodeToString
(
bytes
)
// 存储到 settings 表
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
model
.
SettingKeyAdminApiKey
,
key
);
err
!=
nil
{
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyAdminApiKey
,
key
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"save admin api key: %w"
,
err
)
}
...
...
@@ -300,7 +299,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func
(
s
*
SettingService
)
GetAdminApiKeyStatus
(
ctx
context
.
Context
)
(
maskedKey
string
,
exists
bool
,
err
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyAdminApiKey
)
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminApiKey
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
false
,
nil
...
...
@@ -324,7 +323,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func
(
s
*
SettingService
)
GetAdminApiKey
(
ctx
context
.
Context
)
(
string
,
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
model
.
SettingKeyAdminApiKey
)
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminApiKey
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
nil
// 未配置,返回空字符串
...
...
@@ -336,5 +335,5 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
// DeleteAdminApiKey 删除管理员 API Key
func
(
s
*
SettingService
)
DeleteAdminApiKey
(
ctx
context
.
Context
)
error
{
return
s
.
settingRepo
.
Delete
(
ctx
,
model
.
SettingKeyAdminApiKey
)
return
s
.
settingRepo
.
Delete
(
ctx
,
SettingKeyAdminApiKey
)
}
backend/internal/service/settings_view.go
0 → 100644
View file @
e5a77853
package
service
type
SystemSettings
struct
{
RegistrationEnabled
bool
EmailVerifyEnabled
bool
SmtpHost
string
SmtpPort
int
SmtpUsername
string
SmtpPassword
string
SmtpFrom
string
SmtpFromName
string
SmtpUseTLS
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
TurnstileSecretKey
string
SiteName
string
SiteLogo
string
SiteSubtitle
string
ApiBaseUrl
string
ContactInfo
string
DocUrl
string
DefaultConcurrency
int
DefaultBalance
float64
}
type
PublicSettings
struct
{
RegistrationEnabled
bool
EmailVerifyEnabled
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
SiteName
string
SiteLogo
string
SiteSubtitle
string
ApiBaseUrl
string
ContactInfo
string
DocUrl
string
Version
string
}
backend/internal/service/subscription_service.go
View file @
e5a77853
...
...
@@ -7,7 +7,6 @@ import (
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
...
...
@@ -48,7 +47,7 @@ type AssignSubscriptionInput struct {
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func
(
s
*
SubscriptionService
)
AssignSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
AssignSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
UserSubscription
,
error
)
{
// 检查分组是否存在且为订阅类型
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
input
.
GroupID
)
if
err
!=
nil
{
...
...
@@ -91,7 +90,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅
func
(
s
*
SubscriptionService
)
AssignOrExtendSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
model
.
UserSubscription
,
bool
,
error
)
{
func
(
s
*
SubscriptionService
)
AssignOrExtendSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
UserSubscription
,
bool
,
error
)
{
// 检查分组是否存在且为订阅类型
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
input
.
GroupID
)
if
err
!=
nil
{
...
...
@@ -132,8 +131,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 如果订阅已过期或被暂停,恢复为active状态
if
existingSub
.
Status
!=
model
.
SubscriptionStatusActive
{
if
err
:=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
existingSub
.
ID
,
model
.
SubscriptionStatusActive
);
err
!=
nil
{
if
existingSub
.
Status
!=
SubscriptionStatusActive
{
if
err
:=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
existingSub
.
ID
,
SubscriptionStatusActive
);
err
!=
nil
{
return
nil
,
false
,
fmt
.
Errorf
(
"update subscription status: %w"
,
err
)
}
}
...
...
@@ -185,19 +184,19 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// createSubscription 创建新订阅(内部方法)
func
(
s
*
SubscriptionService
)
createSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
createSubscription
(
ctx
context
.
Context
,
input
*
AssignSubscriptionInput
)
(
*
UserSubscription
,
error
)
{
validityDays
:=
input
.
ValidityDays
if
validityDays
<=
0
{
validityDays
=
30
}
now
:=
time
.
Now
()
sub
:=
&
model
.
UserSubscription
{
sub
:=
&
UserSubscription
{
UserID
:
input
.
UserID
,
GroupID
:
input
.
GroupID
,
StartsAt
:
now
,
ExpiresAt
:
now
.
AddDate
(
0
,
0
,
validityDays
),
Status
:
model
.
SubscriptionStatusActive
,
Status
:
SubscriptionStatusActive
,
AssignedAt
:
now
,
Notes
:
input
.
Notes
,
CreatedAt
:
now
,
...
...
@@ -229,14 +228,14 @@ type BulkAssignSubscriptionInput struct {
type
BulkAssignResult
struct
{
SuccessCount
int
FailedCount
int
Subscriptions
[]
model
.
UserSubscription
Subscriptions
[]
UserSubscription
Errors
[]
string
}
// BulkAssignSubscription 批量分配订阅
func
(
s
*
SubscriptionService
)
BulkAssignSubscription
(
ctx
context
.
Context
,
input
*
BulkAssignSubscriptionInput
)
(
*
BulkAssignResult
,
error
)
{
result
:=
&
BulkAssignResult
{
Subscriptions
:
make
([]
model
.
UserSubscription
,
0
),
Subscriptions
:
make
([]
UserSubscription
,
0
),
Errors
:
make
([]
string
,
0
),
}
...
...
@@ -286,7 +285,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// ExtendSubscription 延长订阅
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetByID
(
ctx
,
subscriptionID
)
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
...
...
@@ -299,8 +298,8 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 如果订阅已过期,恢复为active状态
if
sub
.
Status
==
model
.
SubscriptionStatusExpired
{
if
err
:=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
subscriptionID
,
model
.
SubscriptionStatusActive
);
err
!=
nil
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
if
err
:=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
subscriptionID
,
SubscriptionStatusActive
);
err
!=
nil
{
return
nil
,
err
}
}
...
...
@@ -319,12 +318,12 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// GetByID 根据ID获取订阅
func
(
s
*
SubscriptionService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
UserSubscription
,
error
)
{
return
s
.
userSubRepo
.
GetByID
(
ctx
,
id
)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func
(
s
*
SubscriptionService
)
GetActiveSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
GetActiveSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetActiveByUserIDAndGroupID
(
ctx
,
userID
,
groupID
)
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
...
...
@@ -333,7 +332,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
}
// ListUserSubscriptions 获取用户的所有订阅
func
(
s
*
SubscriptionService
)
ListUserSubscriptions
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
ListUserSubscriptions
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
{
subs
,
err
:=
s
.
userSubRepo
.
ListByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -343,7 +342,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func
(
s
*
SubscriptionService
)
ListActiveUserSubscriptions
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
{
func
(
s
*
SubscriptionService
)
ListActiveUserSubscriptions
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
{
subs
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -353,7 +352,7 @@ func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, u
}
// ListGroupSubscriptions 获取分组的所有订阅
func
(
s
*
SubscriptionService
)
ListGroupSubscriptions
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
model
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
SubscriptionService
)
ListGroupSubscriptions
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
ListByGroupID
(
ctx
,
groupID
,
params
)
if
err
!=
nil
{
...
...
@@ -364,7 +363,7 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
}
// List 获取所有订阅(分页,支持筛选)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
model
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
)
if
err
!=
nil
{
...
...
@@ -376,7 +375,7 @@ func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, user
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
func
normalizeExpiredWindows
(
subs
[]
model
.
UserSubscription
)
{
func
normalizeExpiredWindows
(
subs
[]
UserSubscription
)
{
for
i
:=
range
subs
{
sub
:=
&
subs
[
i
]
// 日窗口过期:清零展示数据
...
...
@@ -403,7 +402,7 @@ func startOfDay(t time.Time) time.Time {
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func
(
s
*
SubscriptionService
)
CheckAndActivateWindow
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
{
func
(
s
*
SubscriptionService
)
CheckAndActivateWindow
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
IsWindowActivated
()
{
return
nil
}
...
...
@@ -414,7 +413,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
}
// CheckAndResetWindows 检查并重置过期的窗口
func
(
s
*
SubscriptionService
)
CheckAndResetWindows
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
{
func
(
s
*
SubscriptionService
)
CheckAndResetWindows
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
// 使用当天零点作为新窗口起始时间
windowStart
:=
startOfDay
(
time
.
Now
())
needsInvalidateCache
:=
false
...
...
@@ -458,7 +457,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func
(
s
*
SubscriptionService
)
CheckUsageLimits
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
,
group
*
model
.
Group
,
additionalCost
float64
)
error
{
func
(
s
*
SubscriptionService
)
CheckUsageLimits
(
ctx
context
.
Context
,
sub
*
UserSubscription
,
group
*
Group
,
additionalCost
float64
)
error
{
if
!
sub
.
CheckDailyLimit
(
group
,
additionalCost
)
{
return
ErrDailyLimitExceeded
}
...
...
@@ -620,16 +619,16 @@ func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (i
}
// ValidateSubscription 验证订阅是否有效
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
{
if
sub
.
Status
==
model
.
SubscriptionStatusExpired
{
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
return
ErrSubscriptionExpired
}
if
sub
.
Status
==
model
.
SubscriptionStatusSuspended
{
if
sub
.
Status
==
SubscriptionStatusSuspended
{
return
ErrSubscriptionSuspended
}
if
sub
.
IsExpired
()
{
// 更新状态
_
=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
sub
.
ID
,
model
.
SubscriptionStatusExpired
)
_
=
s
.
userSubRepo
.
UpdateStatus
(
ctx
,
sub
.
ID
,
SubscriptionStatusExpired
)
return
ErrSubscriptionExpired
}
return
nil
...
...
backend/internal/service/token_refresh_service.go
View file @
e5a77853
...
...
@@ -8,7 +8,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefreshService OAuth token自动刷新服务
...
...
@@ -142,19 +141,19 @@ func (s *TokenRefreshService) processRefresh() {
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
func
(
s
*
TokenRefreshService
)
listActiveAccounts
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
func
(
s
*
TokenRefreshService
)
listActiveAccounts
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
s
.
accountRepo
.
ListActive
(
ctx
)
}
// refreshWithRetry 带重试的刷新
func
(
s
*
TokenRefreshService
)
refreshWithRetry
(
ctx
context
.
Context
,
account
*
model
.
Account
,
refresher
TokenRefresher
)
error
{
func
(
s
*
TokenRefreshService
)
refreshWithRetry
(
ctx
context
.
Context
,
account
*
Account
,
refresher
TokenRefresher
)
error
{
var
lastErr
error
for
attempt
:=
1
;
attempt
<=
s
.
cfg
.
MaxRetries
;
attempt
++
{
newCredentials
,
err
:=
refresher
.
Refresh
(
ctx
,
account
)
if
err
==
nil
{
// 刷新成功,更新账号credentials
account
.
Credentials
=
model
.
JSONB
(
newCredentials
)
account
.
Credentials
=
newCredentials
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
}
...
...
backend/internal/service/token_refresher.go
View file @
e5a77853
...
...
@@ -4,22 +4,20 @@ import (
"context"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
type
TokenRefresher
interface
{
// CanRefresh 检查此刷新器是否能处理指定账号
CanRefresh
(
account
*
model
.
Account
)
bool
CanRefresh
(
account
*
Account
)
bool
// NeedsRefresh 检查账号的token是否需要刷新
NeedsRefresh
(
account
*
model
.
Account
,
refreshWindow
time
.
Duration
)
bool
NeedsRefresh
(
account
*
Account
,
refreshWindow
time
.
Duration
)
bool
// Refresh 执行token刷新,返回更新后的credentials
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
Refresh
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
map
[
string
]
any
,
error
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
}
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
...
...
@@ -37,14 +35,14 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
func
(
r
*
ClaudeTokenRefresher
)
CanRefresh
(
account
*
model
.
Account
)
bool
{
return
account
.
Platform
==
model
.
PlatformAnthropic
&&
account
.
Type
==
model
.
AccountTypeOAuth
func
(
r
*
ClaudeTokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
==
AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func
(
r
*
ClaudeTokenRefresher
)
NeedsRefresh
(
account
*
model
.
Account
,
refreshWindow
time
.
Duration
)
bool
{
func
(
r
*
ClaudeTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
refreshWindow
time
.
Duration
)
bool
{
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
)
if
expiresAtStr
==
""
{
return
false
...
...
@@ -61,7 +59,7 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
func
(
r
*
ClaudeTokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
map
[
string
]
any
,
error
)
{
func
(
r
*
ClaudeTokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
tokenInfo
,
err
:=
r
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -103,14 +101,14 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAIToke
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号
func
(
r
*
OpenAITokenRefresher
)
CanRefresh
(
account
*
model
.
Account
)
bool
{
return
account
.
Platform
==
model
.
PlatformOpenAI
&&
account
.
Type
==
model
.
AccountTypeOAuth
func
(
r
*
OpenAITokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformOpenAI
&&
account
.
Type
==
AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func
(
r
*
OpenAITokenRefresher
)
NeedsRefresh
(
account
*
model
.
Account
,
refreshWindow
time
.
Duration
)
bool
{
func
(
r
*
OpenAITokenRefresher
)
NeedsRefresh
(
account
*
Account
,
refreshWindow
time
.
Duration
)
bool
{
expiresAt
:=
account
.
GetOpenAITokenExpiresAt
()
if
expiresAt
==
nil
{
return
false
...
...
@@ -121,7 +119,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
func
(
r
*
OpenAITokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
map
[
string
]
any
,
error
)
{
func
(
r
*
OpenAITokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
tokenInfo
,
err
:=
r
.
openaiOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
...
...
backend/internal/service/usage_log.go
0 → 100644
View file @
e5a77853
package
service
import
"time"
const
(
BillingTypeBalance
int8
=
0
// 钱包余额
BillingTypeSubscription
int8
=
1
// 订阅套餐
)
type
UsageLog
struct
{
ID
int64
UserID
int64
ApiKeyID
int64
AccountID
int64
RequestID
string
Model
string
GroupID
*
int64
SubscriptionID
*
int64
InputTokens
int
OutputTokens
int
CacheCreationTokens
int
CacheReadTokens
int
CacheCreation5mTokens
int
CacheCreation1hTokens
int
InputCost
float64
OutputCost
float64
CacheCreationCost
float64
CacheReadCost
float64
TotalCost
float64
ActualCost
float64
RateMultiplier
float64
BillingType
int8
Stream
bool
DurationMs
*
int
FirstTokenMs
*
int
CreatedAt
time
.
Time
User
*
User
ApiKey
*
ApiKey
Account
*
Account
Group
*
Group
Subscription
*
UserSubscription
}
func
(
u
*
UsageLog
)
TotalTokens
()
int
{
return
u
.
InputTokens
+
u
.
OutputTokens
+
u
.
CacheCreationTokens
+
u
.
CacheReadTokens
}
backend/internal/service/usage_service.go
View file @
e5a77853
...
...
@@ -6,7 +6,6 @@ import (
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
...
...
@@ -66,7 +65,7 @@ func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *Usa
}
// Create 创建使用日志
func
(
s
*
UsageService
)
Create
(
ctx
context
.
Context
,
req
CreateUsageLogRequest
)
(
*
model
.
UsageLog
,
error
)
{
func
(
s
*
UsageService
)
Create
(
ctx
context
.
Context
,
req
CreateUsageLogRequest
)
(
*
UsageLog
,
error
)
{
// 验证用户存在
_
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
req
.
UserID
)
if
err
!=
nil
{
...
...
@@ -74,7 +73,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// 创建使用日志
usageLog
:=
&
model
.
UsageLog
{
usageLog
:=
&
UsageLog
{
UserID
:
req
.
UserID
,
ApiKeyID
:
req
.
ApiKeyID
,
AccountID
:
req
.
AccountID
,
...
...
@@ -112,7 +111,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// GetByID 根据ID获取使用日志
func
(
s
*
UsageService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UsageLog
,
error
)
{
func
(
s
*
UsageService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
UsageLog
,
error
)
{
log
,
err
:=
s
.
usageRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage log: %w"
,
err
)
...
...
@@ -121,7 +120,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
}
// ListByUser 获取用户的使用日志列表
func
(
s
*
UsageService
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UsageService
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByUser
(
ctx
,
userID
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
...
...
@@ -130,7 +129,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
}
// ListByApiKey 获取API Key的使用日志列表
func
(
s
*
UsageService
)
ListByApiKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UsageService
)
ListByApiKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByApiKey
(
ctx
,
apiKeyID
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
...
...
@@ -139,7 +138,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
}
// ListByAccount 获取账号的使用日志列表
func
(
s
*
UsageService
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UsageService
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByAccount
(
ctx
,
accountID
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
...
...
@@ -243,7 +242,7 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
}
// calculateStats 计算统计数据
func
(
s
*
UsageService
)
calculateStats
(
logs
[]
model
.
UsageLog
)
*
UsageStats
{
func
(
s
*
UsageService
)
calculateStats
(
logs
[]
UsageLog
)
*
UsageStats
{
stats
:=
&
UsageStats
{}
for
_
,
log
:=
range
logs
{
...
...
@@ -313,7 +312,7 @@ func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs [
}
// ListWithFilters lists usage logs with admin filters.
func
(
s
*
UsageService
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UsageService
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
result
,
err
:=
s
.
usageRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs with filters: %w"
,
err
)
...
...
backend/internal/service/user.go
0 → 100644
View file @
e5a77853
package
service
import
(
"time"
"golang.org/x/crypto/bcrypt"
)
type
User
struct
{
ID
int64
Email
string
Username
string
Wechat
string
Notes
string
PasswordHash
string
Role
string
Balance
float64
Concurrency
int
Status
string
AllowedGroups
[]
int64
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
ApiKeys
[]
ApiKey
Subscriptions
[]
UserSubscription
}
func
(
u
*
User
)
IsAdmin
()
bool
{
return
u
.
Role
==
RoleAdmin
}
func
(
u
*
User
)
IsActive
()
bool
{
return
u
.
Status
==
StatusActive
}
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
func
(
u
*
User
)
CanBindGroup
(
groupID
int64
,
isExclusive
bool
)
bool
{
if
len
(
u
.
AllowedGroups
)
>
0
{
for
_
,
id
:=
range
u
.
AllowedGroups
{
if
id
==
groupID
{
return
true
}
}
return
false
}
return
!
isExclusive
}
func
(
u
*
User
)
SetPassword
(
password
string
)
error
{
hash
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
err
}
u
.
PasswordHash
=
string
(
hash
)
return
nil
}
func
(
u
*
User
)
CheckPassword
(
password
string
)
bool
{
return
bcrypt
.
CompareHashAndPassword
([]
byte
(
u
.
PasswordHash
),
[]
byte
(
password
))
==
nil
}
backend/internal/service/user_service.go
View file @
e5a77853
...
...
@@ -5,9 +5,7 @@ import (
"fmt"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt"
)
var
(
...
...
@@ -17,15 +15,15 @@ var (
)
type
UserRepository
interface
{
Create
(
ctx
context
.
Context
,
user
*
model
.
User
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
User
,
error
)
GetByEmail
(
ctx
context
.
Context
,
email
string
)
(
*
model
.
User
,
error
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
model
.
User
,
error
)
Update
(
ctx
context
.
Context
,
user
*
model
.
User
)
error
Create
(
ctx
context
.
Context
,
user
*
User
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
GetByEmail
(
ctx
context
.
Context
,
email
string
)
(
*
User
,
error
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
User
,
error
)
Update
(
ctx
context
.
Context
,
user
*
User
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
User
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
role
,
search
string
)
([]
model
.
User
,
*
pagination
.
PaginationResult
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
role
,
search
string
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
...
...
@@ -61,7 +59,7 @@ func NewUserService(userRepo UserRepository) *UserService {
}
// GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证)
func
(
s
*
UserService
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
model
.
User
,
error
)
{
func
(
s
*
UserService
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
User
,
error
)
{
admin
,
err
:=
s
.
userRepo
.
GetFirstAdmin
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get first admin: %w"
,
err
)
...
...
@@ -70,7 +68,7 @@ func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
}
// GetProfile 获取用户资料
func
(
s
*
UserService
)
GetProfile
(
ctx
context
.
Context
,
userID
int64
)
(
*
model
.
User
,
error
)
{
func
(
s
*
UserService
)
GetProfile
(
ctx
context
.
Context
,
userID
int64
)
(
*
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
...
...
@@ -79,7 +77,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
}
// UpdateProfile 更新用户资料
func
(
s
*
UserService
)
UpdateProfile
(
ctx
context
.
Context
,
userID
int64
,
req
UpdateProfileRequest
)
(
*
model
.
User
,
error
)
{
func
(
s
*
UserService
)
UpdateProfile
(
ctx
context
.
Context
,
userID
int64
,
req
UpdateProfileRequest
)
(
*
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
...
...
@@ -125,18 +123,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// 验证当前密码
if
err
:=
bcrypt
.
CompareHashAndPassword
([]
byte
(
user
.
PasswordHash
),
[]
byte
(
req
.
CurrentPassword
)
);
err
!=
nil
{
if
!
user
.
CheckPassword
(
req
.
CurrentPassword
)
{
return
ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
req
.
NewPassword
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"hash password: %w"
,
err
)
if
err
:=
user
.
SetPassword
(
req
.
NewPassword
);
err
!=
nil
{
return
fmt
.
Errorf
(
"set password: %w"
,
err
)
}
user
.
PasswordHash
=
string
(
hashedPassword
)
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user: %w"
,
err
)
}
...
...
@@ -145,7 +139,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// GetByID 根据ID获取用户(管理员功能)
func
(
s
*
UserService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
User
,
error
)
{
func
(
s
*
UserService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
...
...
@@ -154,7 +148,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
}
// List 获取用户列表(管理员功能)
func
(
s
*
UserService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UserService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
users
,
pagination
,
err
:=
s
.
userRepo
.
List
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list users: %w"
,
err
)
...
...
backend/internal/
model
/user_subscription.go
→
backend/internal/
service
/user_subscription.go
View file @
e5a77853
package
model
package
service
import
(
"time"
)
import
"time"
// 订阅状态常量
const
(
SubscriptionStatusActive
=
"active"
SubscriptionStatusExpired
=
"expired"
SubscriptionStatusSuspended
=
"suspended"
)
// UserSubscription 用户订阅模型
type
UserSubscription
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
GroupID
int64
`gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt
time
.
Time
`gorm:"not null" json:"starts_at"`
ExpiresAt
time
.
Time
`gorm:"not null" json:"expires_at"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/expired/suspended
// 滑动窗口起始时间(nil = 未激活)
DailyWindowStart
*
time
.
Time
`json:"daily_window_start"`
WeeklyWindowStart
*
time
.
Time
`json:"weekly_window_start"`
MonthlyWindowStart
*
time
.
Time
`json:"monthly_window_start"`
// 当前窗口已用额度(USD,基于 total_cost 计算)
DailyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy
*
int64
`gorm:"index" json:"assigned_by"`
AssignedAt
time
.
Time
`gorm:"not null" json:"assigned_at"`
Notes
string
`gorm:"type:text" json:"notes"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser
*
User
`gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
ID
int64
UserID
int64
GroupID
int64
StartsAt
time
.
Time
ExpiresAt
time
.
Time
Status
string
DailyWindowStart
*
time
.
Time
WeeklyWindowStart
*
time
.
Time
MonthlyWindowStart
*
time
.
Time
DailyUsageUSD
float64
WeeklyUsageUSD
float64
MonthlyUsageUSD
float64
AssignedBy
*
int64
AssignedAt
time
.
Time
Notes
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
func
(
UserSubscription
)
TableName
()
string
{
return
"user_subscriptions"
User
*
User
Group
*
Group
AssignedByUser
*
User
}
// IsActive 检查订阅是否有效(状态为active且未过期)
func
(
s
*
UserSubscription
)
IsActive
()
bool
{
return
s
.
Status
==
SubscriptionStatusActive
&&
time
.
Now
()
.
Before
(
s
.
ExpiresAt
)
}
// IsExpired 检查订阅是否已过期
func
(
s
*
UserSubscription
)
IsExpired
()
bool
{
return
time
.
Now
()
.
After
(
s
.
ExpiresAt
)
}
// DaysRemaining 返回订阅剩余天数
func
(
s
*
UserSubscription
)
DaysRemaining
()
int
{
if
s
.
IsExpired
()
{
return
0
...
...
@@ -68,12 +46,10 @@ func (s *UserSubscription) DaysRemaining() int {
return
int
(
time
.
Until
(
s
.
ExpiresAt
)
.
Hours
()
/
24
)
}
// IsWindowActivated 检查窗口是否已激活
func
(
s
*
UserSubscription
)
IsWindowActivated
()
bool
{
return
s
.
DailyWindowStart
!=
nil
||
s
.
WeeklyWindowStart
!=
nil
||
s
.
MonthlyWindowStart
!=
nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsDailyReset
()
bool
{
if
s
.
DailyWindowStart
==
nil
{
return
false
...
...
@@ -81,7 +57,6 @@ func (s *UserSubscription) NeedsDailyReset() bool {
return
time
.
Since
(
*
s
.
DailyWindowStart
)
>=
24
*
time
.
Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsWeeklyReset
()
bool
{
if
s
.
WeeklyWindowStart
==
nil
{
return
false
...
...
@@ -89,7 +64,6 @@ func (s *UserSubscription) NeedsWeeklyReset() bool {
return
time
.
Since
(
*
s
.
WeeklyWindowStart
)
>=
7
*
24
*
time
.
Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsMonthlyReset
()
bool
{
if
s
.
MonthlyWindowStart
==
nil
{
return
false
...
...
@@ -97,7 +71,6 @@ func (s *UserSubscription) NeedsMonthlyReset() bool {
return
time
.
Since
(
*
s
.
MonthlyWindowStart
)
>=
30
*
24
*
time
.
Hour
}
// DailyResetTime 返回日窗口重置时间
func
(
s
*
UserSubscription
)
DailyResetTime
()
*
time
.
Time
{
if
s
.
DailyWindowStart
==
nil
{
return
nil
...
...
@@ -106,7 +79,6 @@ func (s *UserSubscription) DailyResetTime() *time.Time {
return
&
t
}
// WeeklyResetTime 返回周窗口重置时间
func
(
s
*
UserSubscription
)
WeeklyResetTime
()
*
time
.
Time
{
if
s
.
WeeklyWindowStart
==
nil
{
return
nil
...
...
@@ -115,7 +87,6 @@ func (s *UserSubscription) WeeklyResetTime() *time.Time {
return
&
t
}
// MonthlyResetTime 返回月窗口重置时间
func
(
s
*
UserSubscription
)
MonthlyResetTime
()
*
time
.
Time
{
if
s
.
MonthlyWindowStart
==
nil
{
return
nil
...
...
@@ -124,31 +95,27 @@ func (s *UserSubscription) MonthlyResetTime() *time.Time {
return
&
t
}
// CheckDailyLimit 检查是否超出日限额
func
(
s
*
UserSubscription
)
CheckDailyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasDailyLimit
()
{
return
true
// 无限制
return
true
}
return
s
.
DailyUsageUSD
+
additionalCost
<=
*
group
.
DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func
(
s
*
UserSubscription
)
CheckWeeklyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasWeeklyLimit
()
{
return
true
// 无限制
return
true
}
return
s
.
WeeklyUsageUSD
+
additionalCost
<=
*
group
.
WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func
(
s
*
UserSubscription
)
CheckMonthlyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasMonthlyLimit
()
{
return
true
// 无限制
return
true
}
return
s
.
MonthlyUsageUSD
+
additionalCost
<=
*
group
.
MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func
(
s
*
UserSubscription
)
CheckAllLimits
(
group
*
Group
,
additionalCost
float64
)
(
daily
,
weekly
,
monthly
bool
)
{
daily
=
s
.
CheckDailyLimit
(
group
,
additionalCost
)
weekly
=
s
.
CheckWeeklyLimit
(
group
,
additionalCost
)
...
...
backend/internal/service/user_subscription_port.go
View file @
e5a77853
...
...
@@ -4,22 +4,21 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type
UserSubscriptionRepository
interface
{
Create
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UserSubscription
,
error
)
GetByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
model
.
UserSubscription
,
error
)
GetActiveByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
model
.
UserSubscription
,
error
)
Update
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
Create
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
UserSubscription
,
error
)
GetByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
UserSubscription
,
error
)
GetActiveByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
UserSubscription
,
error
)
Update
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
model
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
ExtendExpiry
(
ctx
context
.
Context
,
subscriptionID
int64
,
newExpiresAt
time
.
Time
)
error
...
...
backend/internal/setup/setup.go
View file @
e5a77853
...
...
@@ -10,10 +10,10 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
"gorm.io/driver/postgres"
"gorm.io/gorm"
...
...
@@ -271,8 +271,7 @@ func initializeDatabase(cfg *SetupConfig) error {
}
}()
// 使用 model 包的 AutoMigrate,确保模型定义统一
return
model
.
AutoMigrate
(
db
)
return
repository
.
AutoMigrate
(
db
)
}
func
createAdminUser
(
cfg
*
SetupConfig
)
error
{
...
...
@@ -299,29 +298,28 @@ func createAdminUser(cfg *SetupConfig) error {
// Check if admin already exists
var
count
int64
db
.
Model
(
&
model
.
User
{})
.
Where
(
"role = ?"
,
"admin"
)
.
Count
(
&
count
)
if
err
:=
db
.
Table
(
"users"
)
.
Where
(
"role = ?"
,
service
.
RoleAdmin
)
.
Count
(
&
count
)
.
Error
;
err
!=
nil
{
return
err
}
if
count
>
0
{
return
nil
// Admin already exists
}
// Hash password
hashedPassword
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
cfg
.
Admin
.
Password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
err
admin
:=
&
service
.
User
{
Email
:
cfg
.
Admin
.
Email
,
Role
:
service
.
RoleAdmin
,
Status
:
service
.
StatusActive
,
Balance
:
0
,
Concurrency
:
5
,
CreatedAt
:
time
.
Now
(),
UpdatedAt
:
time
.
Now
(),
}
// Create admin user
admin
:=
&
model
.
User
{
Email
:
cfg
.
Admin
.
Email
,
PasswordHash
:
string
(
hashedPassword
),
Role
:
model
.
RoleAdmin
,
Status
:
model
.
StatusActive
,
Balance
:
0
,
CreatedAt
:
time
.
Now
(),
UpdatedAt
:
time
.
Now
(),
if
err
:=
admin
.
SetPassword
(
cfg
.
Admin
.
Password
);
err
!=
nil
{
return
err
}
return
db
.
Create
(
admin
)
.
Error
return
repository
.
NewUserRepository
(
db
)
.
Create
(
context
.
Background
(),
admin
)
}
func
writeConfigFile
(
cfg
*
SetupConfig
)
error
{
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment