"frontend/vscode:/vscode.git/clone" did not exist on "9abda1bc5917dfb9a4f2eba5fed509fb2e416ac9"
Commit c7abfe67 authored by song's avatar song
Browse files

Merge remote-tracking branch 'upstream/main'

parents 4e3476a6 db6f53e2
...@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
// Async record usage // Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) { go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
...@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account) }(result, account, userAgent)
return return
} }
} }
......
...@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range // Parse date range
var startTime, endTime *time.Time var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" { if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
...@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
if endDateStr := c.Query("end_date"); endDateStr != "" { if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// 获取时间范围参数 // 获取时间范围参数
now := timezone.Now() userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数 // 优先使用 start_date 和 end_date 参数
...@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if startDateStr != "" && endDateStr != "" { if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围 // 使用自定义日期范围
var err error var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr) startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
} }
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr) endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period := c.DefaultQuery("period", "today") period := c.DefaultQuery("period", "today")
switch period { switch period {
case "today": case "today":
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week": case "week":
startTime = now.AddDate(0, 0, -7) startTime = now.AddDate(0, 0, -7)
case "month": case "month":
startTime = now.AddDate(0, -1, 0) startTime = now.AddDate(0, -1, 0)
default: default:
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
} }
endTime = now endTime = now
} }
...@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard // parseUserTimeRange parses start_date, end_date query parameters for user dashboard
// Uses user's timezone if provided, otherwise falls back to server timezone
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now() userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
startDate := c.Query("start_date") startDate := c.Query("start_date")
endDate := c.Query("end_date") endDate := c.Query("end_date")
var startTime, endTime time.Time var startTime, endTime time.Time
if startDate != "" { if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil { if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
startTime = t startTime = t
} else { } else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7)) startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
} }
} else { } else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7)) startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
} }
if endDate != "" { if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil { if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date endTime = t.Add(24 * time.Hour) // Include the end date
} else { } else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1)) endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
} }
} else { } else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1)) endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
} }
return startTime, endTime return startTime, endTime
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
package httpclient package httpclient
import ( import (
"crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
...@@ -40,7 +39,7 @@ type Options struct { ...@@ -40,7 +39,7 @@ type Options struct {
ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h) ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
Timeout time.Duration // 请求总超时时间 Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
...@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) { ...@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
} }
if opts.InsecureSkipVerify { if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // 安全要求:禁止跳过证书验证,避免中间人攻击。
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
} }
proxyURL := strings.TrimSpace(opts.ProxyURL) proxyURL := strings.TrimSpace(opts.ProxyURL)
......
...@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time { ...@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
func ParseInLocation(layout, value string) (time.Time, error) { func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location()) return time.ParseInLocation(layout, value, Location())
} }
// ParseInUserLocation parses a time string in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func ParseInUserLocation(layout, value, userTZ string) (time.Time, error) {
loc := Location() // default to server timezone
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
return time.ParseInLocation(layout, value, loc)
}
// NowInUserLocation returns the current time in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func NowInUserLocation(userTZ string) time.Time {
if userTZ == "" {
return Now()
}
if userLoc, err := time.LoadLocation(userTZ); err == nil {
return time.Now().In(userLoc)
}
return Now()
}
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func StartOfDayInUserLocation(t time.Time, userTZ string) time.Time {
loc := Location()
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
...@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetPriority(account.Priority). SetPriority(account.Priority).
SetStatus(account.Status). SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage). SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable) SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
...@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if account.LastUsedAt != nil { if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt) builder.SetLastUsedAt(*account.LastUsedAt)
} }
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
}
if account.RateLimitedAt != nil { if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt) builder.SetRateLimitedAt(*account.RateLimitedAt)
} }
...@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetPriority(account.Priority). SetPriority(account.Priority).
SetStatus(account.Status). SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage). SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable) SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
...@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else { } else {
builder.ClearLastUsedAt() builder.ClearLastUsedAt()
} }
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
if account.RateLimitedAt != nil { if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt) builder.SetRateLimitedAt(*account.RateLimitedAt)
} else { } else {
...@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco ...@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
dbaccount.StatusEQ(service.StatusActive), dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true), dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(), tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
). ).
...@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf ...@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount.StatusEQ(service.StatusActive), dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true), dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(), tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
). ).
...@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat ...@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount.StatusEQ(service.StatusActive), dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true), dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(), tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
). ).
...@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu ...@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err return err
} }
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET schedulable = FALSE,
updated_at = NOW()
WHERE deleted_at IS NULL
AND schedulable = TRUE
AND auto_pause_on_expired = TRUE
AND expires_at IS NOT NULL
AND expires_at <= $1
`, now)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
...@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in ...@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
preds = append(preds, preds = append(preds,
dbaccount.SchedulableEQ(true), dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(), tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
) )
...@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account { ...@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
}) })
} }
func notExpiredPredicate(now time.Time) dbpredicate.Account {
return dbaccount.Or(
dbaccount.ExpiresAtIsNil(),
dbaccount.ExpiresAtGT(now),
dbaccount.AutoPauseOnExpiredEQ(false),
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot) out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 { if len(accountIDs) == 0 {
...@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Status: m.Status, Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage), ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable, Schedulable: m.Schedulable,
......
...@@ -15,22 +15,32 @@ import ( ...@@ -15,22 +15,32 @@ import (
type githubReleaseClient struct { type githubReleaseClient struct {
httpClient *http.Client httpClient *http.Client
allowPrivateHosts bool downloadHTTPClient *http.Client
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { // NewGitHubReleaseClient 创建 GitHub Release 客户端
allowPrivate := false // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true, ProxyURL: proxyURL,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
} }
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: sharedClient, httpClient: sharedClient,
allowPrivateHosts: allowPrivate, downloadHTTPClient: downloadClient,
} }
} }
...@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err return err
} }
downloadClient, err := httpclient.GetClient(httpclient.Options{ // 使用预配置的下载客户端(已包含代理配置)
Timeout: 10 * time.Minute, resp, err := c.downloadHTTPClient.Do(req)
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -40,7 +40,7 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -40,7 +40,7 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func newTestGitHubReleaseClient() *githubReleaseClient { func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: &http.Client{}, httpClient: &http.Client{},
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
} }
...@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { ...@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { ...@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { ...@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -17,17 +16,12 @@ type pricingRemoteClient struct { ...@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient *http.Client httpClient *http.Client
} }
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { // NewPricingRemoteClient 创建定价数据远程客户端
allowPrivate := false // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
validateResolvedIP := true func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP, ProxyURL: proxyURL,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
...@@ -20,13 +19,7 @@ type PricingServiceSuite struct { ...@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() { func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, ok := NewPricingRemoteClient(&config.Config{ client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
} }
......
...@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { ...@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
validateResolvedIP = cfg.Security.URLAllowlist.Enabled validateResolvedIP = cfg.Security.URLAllowlist.Enabled
} }
if insecure { if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.") log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
} }
return &proxyProbeService{ return &proxyProbeService{
ipInfoURL: defaultIPInfoURL, ipInfoURL: defaultIPInfoURL,
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -109,6 +109,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -109,6 +109,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream, stream,
duration_ms, duration_ms,
first_token_ms, first_token_ms,
user_agent,
image_count, image_count,
image_size, image_size,
created_at created_at
...@@ -118,8 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -118,8 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $20, $21, $22, $23, $24, $25, $26, $27, $28
$25, $26, $27
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -129,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -129,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID) subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
var requestIDArg any var requestIDArg any
...@@ -161,6 +162,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -161,6 +162,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream, log.Stream,
duration, duration,
firstToken, firstToken,
userAgent,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
createdAt, createdAt,
...@@ -1388,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT ...@@ -1388,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return stats, nil return stats, nil
} }
// GetStatsWithFilters gets usage statistics with optional filters
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
args,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account // AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory type AccountUsageHistory = usagestats.AccountUsageHistory
...@@ -1795,6 +1872,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1795,6 +1872,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool stream bool
durationMs sql.NullInt64 durationMs sql.NullInt64
firstTokenMs sql.NullInt64 firstTokenMs sql.NullInt64
userAgent sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
createdAt time.Time createdAt time.Time
...@@ -1826,6 +1904,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1826,6 +1904,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream, &stream,
&durationMs, &durationMs,
&firstTokenMs, &firstTokenMs,
&userAgent,
&imageCount, &imageCount,
&imageSize, &imageSize,
&createdAt, &createdAt,
...@@ -1877,6 +1956,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1877,6 +1956,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64) value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value log.FirstTokenMs = &value
} }
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
......
...@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc ...@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
} }
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet( ...@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
NewPricingRemoteClient, ProvidePricingRemoteClient,
NewGitHubReleaseClient, ProvideGitHubReleaseClient,
NewProxyExitInfoProber, NewProxyExitInfoProber,
NewClaudeUsageFetcher, NewClaudeUsageFetcher,
NewClaudeOAuthClient, NewClaudeOAuthClient,
......
...@@ -1065,6 +1065,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i ...@@ -1065,6 +1065,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct { type stubSettingRepo struct {
all map[string]string all map[string]string
} }
......
...@@ -22,6 +22,8 @@ type Account struct { ...@@ -22,6 +22,8 @@ type Account struct {
Status string Status string
ErrorMessage string ErrorMessage string
LastUsedAt *time.Time LastUsedAt *time.Time
ExpiresAt *time.Time
AutoPauseOnExpired bool
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
...@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool { ...@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return false return false
} }
now := time.Now() now := time.Now()
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) { if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false return false
} }
......
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}
...@@ -38,6 +38,7 @@ type AccountRepository interface { ...@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error) ListSchedulable(ctx context.Context) ([]Account, error)
...@@ -81,6 +82,8 @@ type CreateAccountRequest struct { ...@@ -81,6 +82,8 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
Priority int `json:"priority"` Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
} }
// UpdateAccountRequest 更新账号请求 // UpdateAccountRequest 更新账号请求
...@@ -94,6 +97,8 @@ type UpdateAccountRequest struct { ...@@ -94,6 +97,8 @@ type UpdateAccountRequest struct {
Priority *int `json:"priority"` Priority *int `json:"priority"`
Status *string `json:"status"` Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
} }
// AccountService 账号管理服务 // AccountService 账号管理服务
...@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: StatusActive, Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
...@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Status != nil { if req.Status != nil {
account.Status = *req.Status account.Status = *req.Status
} }
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil { if req.GroupIDs != nil {
......
...@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula ...@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic("unexpected SetSchedulable call") panic("unexpected SetSchedulable call")
} }
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call") panic("unexpected BindGroups call")
} }
......
...@@ -47,6 +47,7 @@ type UsageLogRepository interface { ...@@ -47,6 +47,7 @@ type UsageLogRepository interface {
// Admin usage listing/stats // Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
// Account stats // Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
......
...@@ -132,6 +132,8 @@ type CreateAccountInput struct { ...@@ -132,6 +132,8 @@ type CreateAccountInput struct {
Concurrency int Concurrency int
Priority int Priority int
GroupIDs []int64 GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk. // This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool SkipMixedChannelCheck bool
...@@ -148,6 +150,8 @@ type UpdateAccountInput struct { ...@@ -148,6 +150,8 @@ type UpdateAccountInput struct {
Priority *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ExpiresAt *int64
AutoPauseOnExpired *bool
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
} }
...@@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -700,6 +704,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive, Status: StatusActive,
Schedulable: true, Schedulable: true,
} }
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
} }
...@@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -755,6 +768,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Status != "" { if input.Status != "" {
account.Status = input.Status account.Status = input.Status
} }
if input.ExpiresAt != nil {
if *input.ExpiresAt <= 0 {
account.ExpiresAt = nil
} else {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil { if input.GroupIDs != nil {
......
...@@ -20,12 +20,16 @@ var ( ...@@ -20,12 +20,16 @@ var (
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
) )
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
type JWTClaims struct { type JWTClaims struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
...@@ -309,7 +313,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -309,7 +313,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
if len(tokenString) > maxTokenLength {
return nil, ErrTokenTooLarge
}
// 使用解析器并限制可接受的签名算法,防止算法混淆。
parser := jwt.NewParser(jwt.WithValidMethods([]string{
jwt.SigningMethodHS256.Name,
jwt.SigningMethodHS384.Name,
jwt.SigningMethodHS512.Name,
}))
// 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。
token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// 验证签名方法 // 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment