Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c7abfe67
Commit
c7abfe67
authored
Jan 08, 2026
by
song
Browse files
Merge remote-tracking branch 'upstream/main'
parents
4e3476a6
db6f53e2
Changes
99
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/openai_gateway_handler.go
View file @
c7abfe67
...
...
@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Async record usage
go
func
(
result
*
service
.
OpenAIForwardResult
,
usedAccount
*
service
.
Account
)
{
go
func
(
result
*
service
.
OpenAIForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
OpenAIRecordUsageInput
{
...
...
@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User
:
apiKey
.
User
,
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
)
}(
result
,
account
,
userAgent
)
return
}
}
...
...
backend/internal/handler/usage_handler.go
View file @
c7abfe67
...
...
@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range
var
startTime
,
endTime
*
time
.
Time
userTZ
:=
c
.
Query
(
"timezone"
)
// Get user's timezone from request
if
startDateStr
:=
c
.
Query
(
"start_date"
);
startDateStr
!=
""
{
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
startDateStr
)
t
,
err
:=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
startDateStr
,
userTZ
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid start_date format, use YYYY-MM-DD"
)
return
...
...
@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
}
if
endDateStr
:=
c
.
Query
(
"end_date"
);
endDateStr
!=
""
{
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
endDateStr
)
t
,
err
:=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
endDateStr
,
userTZ
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid end_date format, use YYYY-MM-DD"
)
return
...
...
@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// 获取时间范围参数
now
:=
timezone
.
Now
()
userTZ
:=
c
.
Query
(
"timezone"
)
// Get user's timezone from request
now
:=
timezone
.
NowInUserLocation
(
userTZ
)
var
startTime
,
endTime
time
.
Time
// 优先使用 start_date 和 end_date 参数
...
...
@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if
startDateStr
!=
""
&&
endDateStr
!=
""
{
// 使用自定义日期范围
var
err
error
startTime
,
err
=
timezone
.
ParseInLocation
(
"2006-01-02"
,
startDateStr
)
startTime
,
err
=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
startDateStr
,
userTZ
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid start_date format, use YYYY-MM-DD"
)
return
}
endTime
,
err
=
timezone
.
ParseInLocation
(
"2006-01-02"
,
endDateStr
)
endTime
,
err
=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
endDateStr
,
userTZ
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid end_date format, use YYYY-MM-DD"
)
return
...
...
@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period
:=
c
.
DefaultQuery
(
"period"
,
"today"
)
switch
period
{
case
"today"
:
startTime
=
timezone
.
StartOfDay
(
now
)
startTime
=
timezone
.
StartOfDay
InUserLocation
(
now
,
userTZ
)
case
"week"
:
startTime
=
now
.
AddDate
(
0
,
0
,
-
7
)
case
"month"
:
startTime
=
now
.
AddDate
(
0
,
-
1
,
0
)
default
:
startTime
=
timezone
.
StartOfDay
(
now
)
startTime
=
timezone
.
StartOfDay
InUserLocation
(
now
,
userTZ
)
}
endTime
=
now
}
...
...
@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
// Uses user's timezone if provided, otherwise falls back to server timezone
func
parseUserTimeRange
(
c
*
gin
.
Context
)
(
time
.
Time
,
time
.
Time
)
{
now
:=
timezone
.
Now
()
userTZ
:=
c
.
Query
(
"timezone"
)
// Get user's timezone from request
now
:=
timezone
.
NowInUserLocation
(
userTZ
)
startDate
:=
c
.
Query
(
"start_date"
)
endDate
:=
c
.
Query
(
"end_date"
)
var
startTime
,
endTime
time
.
Time
if
startDate
!=
""
{
if
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
startDate
);
err
==
nil
{
if
t
,
err
:=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
startDate
,
userTZ
);
err
==
nil
{
startTime
=
t
}
else
{
startTime
=
timezone
.
StartOfDay
(
now
.
AddDate
(
0
,
0
,
-
7
))
startTime
=
timezone
.
StartOfDay
InUserLocation
(
now
.
AddDate
(
0
,
0
,
-
7
)
,
userTZ
)
}
}
else
{
startTime
=
timezone
.
StartOfDay
(
now
.
AddDate
(
0
,
0
,
-
7
))
startTime
=
timezone
.
StartOfDay
InUserLocation
(
now
.
AddDate
(
0
,
0
,
-
7
)
,
userTZ
)
}
if
endDate
!=
""
{
if
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
endDate
);
err
==
nil
{
if
t
,
err
:=
timezone
.
ParseIn
User
Location
(
"2006-01-02"
,
endDate
,
userTZ
);
err
==
nil
{
endTime
=
t
.
Add
(
24
*
time
.
Hour
)
// Include the end date
}
else
{
endTime
=
timezone
.
StartOfDay
(
now
.
AddDate
(
0
,
0
,
1
))
endTime
=
timezone
.
StartOfDay
InUserLocation
(
now
.
AddDate
(
0
,
0
,
1
)
,
userTZ
)
}
}
else
{
endTime
=
timezone
.
StartOfDay
(
now
.
AddDate
(
0
,
0
,
1
))
endTime
=
timezone
.
StartOfDay
InUserLocation
(
now
.
AddDate
(
0
,
0
,
1
)
,
userTZ
)
}
return
startTime
,
endTime
...
...
backend/internal/pkg/httpclient/pool.go
View file @
c7abfe67
...
...
@@ -16,7 +16,6 @@
package
httpclient
import
(
"crypto/tls"
"fmt"
"net/http"
"net/url"
...
...
@@ -40,7 +39,7 @@ type Options struct {
ProxyURL
string
// 代理 URL(支持 http/https/socks5/socks5h)
Timeout
time
.
Duration
// 请求总超时时间
ResponseHeaderTimeout
time
.
Duration
// 等待响应头超时时间
InsecureSkipVerify
bool
// 是否跳过 TLS 证书验证
InsecureSkipVerify
bool
// 是否跳过 TLS 证书验证
(已禁用,不允许设置为 true)
ProxyStrict
bool
// 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP
bool
// 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts
bool
// 允许私有地址解析(与 ValidateResolvedIP 一起使用)
...
...
@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
if
opts
.
InsecureSkipVerify
{
transport
.
TLSClientConfig
=
&
tls
.
Config
{
InsecureSkipVerify
:
true
}
// 安全要求:禁止跳过证书验证,避免中间人攻击。
return
nil
,
fmt
.
Errorf
(
"insecure_skip_verify is not allowed; install a trusted certificate instead"
)
}
proxyURL
:=
strings
.
TrimSpace
(
opts
.
ProxyURL
)
...
...
backend/internal/pkg/timezone/timezone.go
View file @
c7abfe67
...
...
@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
func
ParseInLocation
(
layout
,
value
string
)
(
time
.
Time
,
error
)
{
return
time
.
ParseInLocation
(
layout
,
value
,
Location
())
}
// ParseInUserLocation parses a time string in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func
ParseInUserLocation
(
layout
,
value
,
userTZ
string
)
(
time
.
Time
,
error
)
{
loc
:=
Location
()
// default to server timezone
if
userTZ
!=
""
{
if
userLoc
,
err
:=
time
.
LoadLocation
(
userTZ
);
err
==
nil
{
loc
=
userLoc
}
}
return
time
.
ParseInLocation
(
layout
,
value
,
loc
)
}
// NowInUserLocation returns the current time in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func
NowInUserLocation
(
userTZ
string
)
time
.
Time
{
if
userTZ
==
""
{
return
Now
()
}
if
userLoc
,
err
:=
time
.
LoadLocation
(
userTZ
);
err
==
nil
{
return
time
.
Now
()
.
In
(
userLoc
)
}
return
Now
()
}
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func
StartOfDayInUserLocation
(
t
time
.
Time
,
userTZ
string
)
time
.
Time
{
loc
:=
Location
()
if
userTZ
!=
""
{
if
userLoc
,
err
:=
time
.
LoadLocation
(
userTZ
);
err
==
nil
{
loc
=
userLoc
}
}
t
=
t
.
In
(
loc
)
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
loc
)
}
backend/internal/repository/account_repo.go
View file @
c7abfe67
...
...
@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetPriority
(
account
.
Priority
)
.
SetStatus
(
account
.
Status
)
.
SetErrorMessage
(
account
.
ErrorMessage
)
.
SetSchedulable
(
account
.
Schedulable
)
SetSchedulable
(
account
.
Schedulable
)
.
SetAutoPauseOnExpired
(
account
.
AutoPauseOnExpired
)
if
account
.
ProxyID
!=
nil
{
builder
.
SetProxyID
(
*
account
.
ProxyID
)
...
...
@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if
account
.
LastUsedAt
!=
nil
{
builder
.
SetLastUsedAt
(
*
account
.
LastUsedAt
)
}
if
account
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
account
.
ExpiresAt
)
}
if
account
.
RateLimitedAt
!=
nil
{
builder
.
SetRateLimitedAt
(
*
account
.
RateLimitedAt
)
}
...
...
@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetPriority
(
account
.
Priority
)
.
SetStatus
(
account
.
Status
)
.
SetErrorMessage
(
account
.
ErrorMessage
)
.
SetSchedulable
(
account
.
Schedulable
)
SetSchedulable
(
account
.
Schedulable
)
.
SetAutoPauseOnExpired
(
account
.
AutoPauseOnExpired
)
if
account
.
ProxyID
!=
nil
{
builder
.
SetProxyID
(
*
account
.
ProxyID
)
...
...
@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
else
{
builder
.
ClearLastUsedAt
()
}
if
account
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
account
.
ExpiresAt
)
}
else
{
builder
.
ClearExpiresAt
()
}
if
account
.
RateLimitedAt
!=
nil
{
builder
.
SetRateLimitedAt
(
*
account
.
RateLimitedAt
)
}
else
{
...
...
@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
true
),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
.
...
...
@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
true
),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
.
...
...
@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
true
),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
.
...
...
@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return
err
}
func
(
r
*
accountRepository
)
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
{
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE accounts
SET schedulable = FALSE,
updated_at = NOW()
WHERE deleted_at IS NULL
AND schedulable = TRUE
AND auto_pause_on_expired = TRUE
AND expires_at IS NOT NULL
AND expires_at <= $1
`
,
now
)
if
err
!=
nil
{
return
0
,
err
}
rows
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
0
,
err
}
return
rows
,
nil
}
func
(
r
*
accountRepository
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
if
len
(
updates
)
==
0
{
return
nil
...
...
@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
preds
=
append
(
preds
,
dbaccount
.
SchedulableEQ
(
true
),
tempUnschedulablePredicate
(),
notExpiredPredicate
(
now
),
dbaccount
.
Or
(
dbaccount
.
OverloadUntilIsNil
(),
dbaccount
.
OverloadUntilLTE
(
now
)),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
now
)),
)
...
...
@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
})
}
func
notExpiredPredicate
(
now
time
.
Time
)
dbpredicate
.
Account
{
return
dbaccount
.
Or
(
dbaccount
.
ExpiresAtIsNil
(),
dbaccount
.
ExpiresAtGT
(
now
),
dbaccount
.
AutoPauseOnExpiredEQ
(
false
),
)
}
func
(
r
*
accountRepository
)
loadTempUnschedStates
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
tempUnschedSnapshot
,
error
)
{
out
:=
make
(
map
[
int64
]
tempUnschedSnapshot
)
if
len
(
accountIDs
)
==
0
{
...
...
@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Status
:
m
.
Status
,
ErrorMessage
:
derefString
(
m
.
ErrorMessage
),
LastUsedAt
:
m
.
LastUsedAt
,
ExpiresAt
:
m
.
ExpiresAt
,
AutoPauseOnExpired
:
m
.
AutoPauseOnExpired
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
Schedulable
:
m
.
Schedulable
,
...
...
backend/internal/repository/github_release_service.go
View file @
c7abfe67
...
...
@@ -14,23 +14,33 @@ import (
)
type
githubReleaseClient
struct
{
httpClient
*
http
.
Client
allowPrivateHosts
bool
httpClient
*
http
.
Client
downloadHTTPClient
*
http
.
Client
}
func
NewGitHubReleaseClient
()
service
.
GitHubReleaseClient
{
allowPrivate
:=
false
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
func
NewGitHubReleaseClient
(
proxyURL
string
)
service
.
GitHubReleaseClient
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
allowPrivate
,
Timeout
:
30
*
time
.
Second
,
ProxyURL
:
proxyURL
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
}
// 下载客户端需要更长的超时时间
downloadClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Minute
,
ProxyURL
:
proxyURL
,
})
if
err
!=
nil
{
downloadClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Minute
}
}
return
&
githubReleaseClient
{
httpClient
:
sharedClient
,
allowPrivateHosts
:
allowPrivate
,
httpClient
:
sharedClient
,
downloadHTTPClient
:
downloadClient
,
}
}
...
...
@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return
err
}
downloadClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Minute
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
c
.
allowPrivateHosts
,
})
if
err
!=
nil
{
downloadClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Minute
}
}
resp
,
err
:=
downloadClient
.
Do
(
req
)
// 使用预配置的下载客户端(已包含代理配置)
resp
,
err
:=
c
.
downloadHTTPClient
.
Do
(
req
)
if
err
!=
nil
{
return
err
}
...
...
backend/internal/repository/github_release_service_test.go
View file @
c7abfe67
...
...
@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func
newTestGitHubReleaseClient
()
*
githubReleaseClient
{
return
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{},
allowPrivateHosts
:
true
,
httpClient
:
&
http
.
Client
{},
downloadHTTPClient
:
&
http
.
Client
{}
,
}
}
...
...
@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
downloadHTTPClient
:
&
http
.
Client
{}
,
}
release
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
downloadHTTPClient
:
&
http
.
Client
{}
,
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
downloadHTTPClient
:
&
http
.
Client
{}
,
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
downloadHTTPClient
:
&
http
.
Client
{}
,
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
...
...
backend/internal/repository/pricing_service.go
View file @
c7abfe67
...
...
@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
...
...
@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient
*
http
.
Client
}
func
NewPricingRemoteClient
(
cfg
*
config
.
Config
)
service
.
PricingRemoteClient
{
allowPrivate
:=
false
validateResolvedIP
:=
true
if
cfg
!=
nil
{
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func
NewPricingRemoteClient
(
proxyURL
string
)
service
.
PricingRemoteClient
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
validateResolvedIP
,
AllowPrivateHosts
:
allowPrivate
,
Timeout
:
30
*
time
.
Second
,
ProxyURL
:
proxyURL
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
backend/internal/repository/pricing_service_test.go
View file @
c7abfe67
...
...
@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
...
...
@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func
(
s
*
PricingServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
client
,
ok
:=
NewPricingRemoteClient
(
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
AllowPrivateHosts
:
true
,
},
},
})
.
(
*
pricingRemoteClient
)
client
,
ok
:=
NewPricingRemoteClient
(
""
)
.
(
*
pricingRemoteClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
}
...
...
backend/internal/repository/proxy_probe_service.go
View file @
c7abfe67
...
...
@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
validateResolvedIP
=
cfg
.
Security
.
URLAllowlist
.
Enabled
}
if
insecure
{
log
.
Printf
(
"[ProxyProbe] Warning:
TLS verification is disabled for proxy probing
."
)
log
.
Printf
(
"[ProxyProbe] Warning:
insecure_skip_verify is not allowed and will cause probe failure
."
)
}
return
&
proxyProbeService
{
ipInfoURL
:
defaultIPInfoURL
,
...
...
backend/internal/repository/usage_log_repo.go
View file @
c7abfe67
...
...
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms,
user_agent,
image_count, image_size, created_at"
type
usageLogRepository
struct
{
client
*
dbent
.
Client
...
...
@@ -109,6 +109,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream,
duration_ms,
first_token_ms,
user_agent,
image_count,
image_size,
created_at
...
...
@@ -118,8 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24,
$25, $26, $27
$20, $21, $22, $23, $24, $25, $26, $27, $28
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -129,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID
:=
nullInt64
(
log
.
SubscriptionID
)
duration
:=
nullInt
(
log
.
DurationMs
)
firstToken
:=
nullInt
(
log
.
FirstTokenMs
)
userAgent
:=
nullString
(
log
.
UserAgent
)
imageSize
:=
nullString
(
log
.
ImageSize
)
var
requestIDArg
any
...
...
@@ -161,6 +162,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log
.
Stream
,
duration
,
firstToken
,
userAgent
,
log
.
ImageCount
,
imageSize
,
createdAt
,
...
...
@@ -1388,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return
stats
,
nil
}
// GetStatsWithFilters gets usage statistics with optional filters
func
(
r
*
usageLogRepository
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
UsageLogFilters
)
(
*
UsageStats
,
error
)
{
conditions
:=
make
([]
string
,
0
,
9
)
args
:=
make
([]
any
,
0
,
9
)
if
filters
.
UserID
>
0
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"user_id = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
filters
.
UserID
)
}
if
filters
.
APIKeyID
>
0
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"api_key_id = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
filters
.
APIKeyID
)
}
if
filters
.
AccountID
>
0
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"account_id = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
filters
.
AccountID
)
}
if
filters
.
GroupID
>
0
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"group_id = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
filters
.
GroupID
)
}
if
filters
.
Model
!=
""
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"model = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
filters
.
Model
)
}
if
filters
.
Stream
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"stream = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
*
filters
.
Stream
)
}
if
filters
.
BillingType
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"billing_type = $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
int16
(
*
filters
.
BillingType
))
}
if
filters
.
StartTime
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at >= $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
*
filters
.
StartTime
)
}
if
filters
.
EndTime
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at <= $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
*
filters
.
EndTime
)
}
query
:=
fmt
.
Sprintf
(
`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`
,
buildWhere
(
conditions
))
stats
:=
&
UsageStats
{}
if
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
query
,
args
,
&
stats
.
TotalRequests
,
&
stats
.
TotalInputTokens
,
&
stats
.
TotalOutputTokens
,
&
stats
.
TotalCacheTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
stats
.
AverageDurationMs
,
);
err
!=
nil
{
return
nil
,
err
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
return
stats
,
nil
}
// AccountUsageHistory represents daily usage history for an account
type
AccountUsageHistory
=
usagestats
.
AccountUsageHistory
...
...
@@ -1795,6 +1872,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream
bool
durationMs
sql
.
NullInt64
firstTokenMs
sql
.
NullInt64
userAgent
sql
.
NullString
imageCount
int
imageSize
sql
.
NullString
createdAt
time
.
Time
...
...
@@ -1826,6 +1904,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
stream
,
&
durationMs
,
&
firstTokenMs
,
&
userAgent
,
&
imageCount
,
&
imageSize
,
&
createdAt
,
...
...
@@ -1877,6 +1956,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value
:=
int
(
firstTokenMs
.
Int64
)
log
.
FirstTokenMs
=
&
value
}
if
userAgent
.
Valid
{
log
.
UserAgent
=
&
userAgent
.
String
}
if
imageSize
.
Valid
{
log
.
ImageSize
=
&
imageSize
.
String
}
...
...
backend/internal/repository/wire.go
View file @
c7abfe67
...
...
@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return
NewConcurrencyCache
(
rdb
,
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
,
waitTTLSeconds
)
}
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func
ProvideGitHubReleaseClient
(
cfg
*
config
.
Config
)
service
.
GitHubReleaseClient
{
return
NewGitHubReleaseClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func
ProvidePricingRemoteClient
(
cfg
*
config
.
Config
)
service
.
PricingRemoteClient
{
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
...
...
@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier
,
New
PricingRemoteClient
,
New
GitHubReleaseClient
,
Provide
PricingRemoteClient
,
Provide
GitHubReleaseClient
,
NewProxyExitInfoProber
,
NewClaudeUsageFetcher
,
NewClaudeOAuthClient
,
...
...
backend/internal/server/api_contract_test.go
View file @
c7abfe67
...
...
@@ -1065,6 +1065,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
usagestats
.
UsageLogFilters
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubSettingRepo
struct
{
all
map
[
string
]
string
}
...
...
backend/internal/service/account.go
View file @
c7abfe67
...
...
@@ -9,21 +9,23 @@ import (
)
type
Account
struct
{
ID
int64
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
ProxyID
*
int64
Concurrency
int
Priority
int
Status
string
ErrorMessage
string
LastUsedAt
*
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
ID
int64
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
ProxyID
*
int64
Concurrency
int
Priority
int
Status
string
ErrorMessage
string
LastUsedAt
*
time
.
Time
ExpiresAt
*
time
.
Time
AutoPauseOnExpired
bool
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
Schedulable
bool
...
...
@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return
false
}
now
:=
time
.
Now
()
if
a
.
AutoPauseOnExpired
&&
a
.
ExpiresAt
!=
nil
&&
!
now
.
Before
(
*
a
.
ExpiresAt
)
{
return
false
}
if
a
.
OverloadUntil
!=
nil
&&
now
.
Before
(
*
a
.
OverloadUntil
)
{
return
false
}
...
...
backend/internal/service/account_expiry_service.go
0 → 100644
View file @
c7abfe67
package
service
import
(
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type
AccountExpiryService
struct
{
accountRepo
AccountRepository
interval
time
.
Duration
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
func
NewAccountExpiryService
(
accountRepo
AccountRepository
,
interval
time
.
Duration
)
*
AccountExpiryService
{
return
&
AccountExpiryService
{
accountRepo
:
accountRepo
,
interval
:
interval
,
stopCh
:
make
(
chan
struct
{}),
}
}
func
(
s
*
AccountExpiryService
)
Start
()
{
if
s
==
nil
||
s
.
accountRepo
==
nil
||
s
.
interval
<=
0
{
return
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
s
.
interval
)
defer
ticker
.
Stop
()
s
.
runOnce
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
runOnce
()
case
<-
s
.
stopCh
:
return
}
}
}()
}
func
(
s
*
AccountExpiryService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
}
func
(
s
*
AccountExpiryService
)
runOnce
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
updated
,
err
:=
s
.
accountRepo
.
AutoPauseExpiredAccounts
(
ctx
,
time
.
Now
())
if
err
!=
nil
{
log
.
Printf
(
"[AccountExpiry] Auto pause expired accounts failed: %v"
,
err
)
return
}
if
updated
>
0
{
log
.
Printf
(
"[AccountExpiry] Auto paused %d expired accounts"
,
updated
)
}
}
backend/internal/service/account_service.go
View file @
c7abfe67
...
...
@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
...
...
@@ -71,29 +72,33 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求
type
CreateAccountRequest
struct
{
Name
string
`json:"name"`
Notes
*
string
`json:"notes"`
Platform
string
`json:"platform"`
Type
string
`json:"type"`
Credentials
map
[
string
]
any
`json:"credentials"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
int
`json:"concurrency"`
Priority
int
`json:"priority"`
GroupIDs
[]
int64
`json:"group_ids"`
Name
string
`json:"name"`
Notes
*
string
`json:"notes"`
Platform
string
`json:"platform"`
Type
string
`json:"type"`
Credentials
map
[
string
]
any
`json:"credentials"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
int
`json:"concurrency"`
Priority
int
`json:"priority"`
GroupIDs
[]
int64
`json:"group_ids"`
ExpiresAt
*
time
.
Time
`json:"expires_at"`
AutoPauseOnExpired
*
bool
`json:"auto_pause_on_expired"`
}
// UpdateAccountRequest 更新账号请求
type
UpdateAccountRequest
struct
{
Name
*
string
`json:"name"`
Notes
*
string
`json:"notes"`
Credentials
*
map
[
string
]
any
`json:"credentials"`
Extra
*
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
Status
*
string
`json:"status"`
GroupIDs
*
[]
int64
`json:"group_ids"`
Name
*
string
`json:"name"`
Notes
*
string
`json:"notes"`
Credentials
*
map
[
string
]
any
`json:"credentials"`
Extra
*
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
Status
*
string
`json:"status"`
GroupIDs
*
[]
int64
`json:"group_ids"`
ExpiresAt
*
time
.
Time
`json:"expires_at"`
AutoPauseOnExpired
*
bool
`json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
...
...
@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency
:
req
.
Concurrency
,
Priority
:
req
.
Priority
,
Status
:
StatusActive
,
ExpiresAt
:
req
.
ExpiresAt
,
}
if
req
.
AutoPauseOnExpired
!=
nil
{
account
.
AutoPauseOnExpired
=
*
req
.
AutoPauseOnExpired
}
else
{
account
.
AutoPauseOnExpired
=
true
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
...
...
@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if
req
.
Status
!=
nil
{
account
.
Status
=
*
req
.
Status
}
if
req
.
ExpiresAt
!=
nil
{
account
.
ExpiresAt
=
req
.
ExpiresAt
}
if
req
.
AutoPauseOnExpired
!=
nil
{
account
.
AutoPauseOnExpired
=
*
req
.
AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if
req
.
GroupIDs
!=
nil
{
...
...
backend/internal/service/account_service_delete_test.go
View file @
c7abfe67
...
...
@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic
(
"unexpected SetSchedulable call"
)
}
func
(
s
*
accountRepoStub
)
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
{
panic
(
"unexpected AutoPauseExpiredAccounts call"
)
}
func
(
s
*
accountRepoStub
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
panic
(
"unexpected BindGroups call"
)
}
...
...
backend/internal/service/account_usage_service.go
View file @
c7abfe67
...
...
@@ -47,6 +47,7 @@ type UsageLogRepository interface {
// Admin usage listing/stats
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
usagestats
.
UsageLogFilters
)
(
*
usagestats
.
UsageStats
,
error
)
// Account stats
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
AccountUsageStatsResponse
,
error
)
...
...
backend/internal/service/admin_service.go
View file @
c7abfe67
...
...
@@ -122,16 +122,18 @@ type UpdateGroupInput struct {
}
type
CreateAccountInput
struct
{
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
ProxyID
*
int64
Concurrency
int
Priority
int
GroupIDs
[]
int64
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
ProxyID
*
int64
Concurrency
int
Priority
int
GroupIDs
[]
int64
ExpiresAt
*
int64
AutoPauseOnExpired
*
bool
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck
bool
...
...
@@ -148,6 +150,8 @@ type UpdateAccountInput struct {
Priority
*
int
// 使用指针区分"未提供"和"设置为0"
Status
string
GroupIDs
*
[]
int64
ExpiresAt
*
int64
AutoPauseOnExpired
*
bool
SkipMixedChannelCheck
bool
// 跳过混合渠道检查(用户已确认风险)
}
...
...
@@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status
:
StatusActive
,
Schedulable
:
true
,
}
if
input
.
ExpiresAt
!=
nil
&&
*
input
.
ExpiresAt
>
0
{
expiresAt
:=
time
.
Unix
(
*
input
.
ExpiresAt
,
0
)
account
.
ExpiresAt
=
&
expiresAt
}
if
input
.
AutoPauseOnExpired
!=
nil
{
account
.
AutoPauseOnExpired
=
*
input
.
AutoPauseOnExpired
}
else
{
account
.
AutoPauseOnExpired
=
true
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if
input
.
Status
!=
""
{
account
.
Status
=
input
.
Status
}
if
input
.
ExpiresAt
!=
nil
{
if
*
input
.
ExpiresAt
<=
0
{
account
.
ExpiresAt
=
nil
}
else
{
expiresAt
:=
time
.
Unix
(
*
input
.
ExpiresAt
,
0
)
account
.
ExpiresAt
=
&
expiresAt
}
}
if
input
.
AutoPauseOnExpired
!=
nil
{
account
.
AutoPauseOnExpired
=
*
input
.
AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if
input
.
GroupIDs
!=
nil
{
...
...
backend/internal/service/auth_service.go
View file @
c7abfe67
...
...
@@ -20,12 +20,16 @@ var (
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const
maxTokenLength
=
8192
// JWTClaims JWT载荷数据
type
JWTClaims
struct
{
UserID
int64
`json:"user_id"`
...
...
@@ -309,7 +313,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// ValidateToken 验证JWT token并返回用户声明
func
(
s
*
AuthService
)
ValidateToken
(
tokenString
string
)
(
*
JWTClaims
,
error
)
{
token
,
err
:=
jwt
.
ParseWithClaims
(
tokenString
,
&
JWTClaims
{},
func
(
token
*
jwt
.
Token
)
(
any
,
error
)
{
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
if
len
(
tokenString
)
>
maxTokenLength
{
return
nil
,
ErrTokenTooLarge
}
// 使用解析器并限制可接受的签名算法,防止算法混淆。
parser
:=
jwt
.
NewParser
(
jwt
.
WithValidMethods
([]
string
{
jwt
.
SigningMethodHS256
.
Name
,
jwt
.
SigningMethodHS384
.
Name
,
jwt
.
SigningMethodHS512
.
Name
,
}))
// 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。
token
,
err
:=
parser
.
ParseWithClaims
(
tokenString
,
&
JWTClaims
{},
func
(
token
*
jwt
.
Token
)
(
any
,
error
)
{
// 验证签名方法
if
_
,
ok
:=
token
.
Method
.
(
*
jwt
.
SigningMethodHMAC
);
!
ok
{
return
nil
,
fmt
.
Errorf
(
"unexpected signing method: %v"
,
token
.
Header
[
"alg"
])
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment