Unverified Commit 2fe8932c authored by Call White's avatar Call White Committed by GitHub
Browse files

Merge pull request #3 from cyhhao/main

merge to main
parents 2f2e76f9 adb77af1
package service
import "context"
import (
"context"
"log/slog"
"strconv"
)
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
......@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
return nil
}
var cacheKey string
var keysToDelete []string
accountIDKey := "account:" + strconv.FormatInt(account.ID, 10)
switch account.Platform {
case PlatformGemini:
cacheKey = GeminiTokenCacheKey(account)
// Gemini 可能有两种缓存键:project_id 或 account_id
// 首次获取 token 时可能没有 project_id,之后自动检测到 project_id 后会使用新 key
// 刷新时需要同时删除两种可能的 key,确保不会遗留旧缓存
keysToDelete = append(keysToDelete, GeminiTokenCacheKey(account))
keysToDelete = append(keysToDelete, "gemini:"+accountIDKey)
case PlatformAntigravity:
cacheKey = AntigravityTokenCacheKey(account)
// Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformOpenAI:
cacheKey = OpenAITokenCacheKey(account)
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic:
cacheKey = ClaudeTokenCacheKey(account)
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
default:
return nil
}
return c.cache.DeleteAccessToken(ctx, cacheKey)
// 删除所有可能的缓存键(去重后)
seen := make(map[string]bool)
for _, key := range keysToDelete {
if seen[key] {
continue
}
seen[key] = true
if err := c.cache.DeleteAccessToken(ctx, key); err != nil {
slog.Warn("token_cache_delete_failed", "key", key, "account_id", account.ID, "error", err)
}
}
return nil
}
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
// 用于解决异步刷新任务与请求线程的竞态条件:
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
//
// 返回值:
// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil)
// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account
func CheckTokenVersion(ctx context.Context, account *Account, repo AccountRepository) (latestAccount *Account, isStale bool) {
if account == nil || repo == nil {
return nil, false
}
currentVersion := account.GetCredentialAsInt64("_token_version")
latestAccount, err := repo.GetByID(ctx, account.ID)
if err != nil || latestAccount == nil {
// 查询失败,默认允许缓存,不返回 latestAccount
return nil, false
}
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
// 说明异步刷新任务已更新 token,当前 account 已过时
if currentVersion == 0 && latestVersion > 0 {
slog.Debug("token_version_stale_no_current_version",
"account_id", account.ID,
"latest_version", latestVersion)
return latestAccount, true
}
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
if currentVersion == 0 && latestVersion == 0 {
return latestAccount, false
}
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
if latestVersion > currentVersion {
slog.Debug("token_version_stale",
"account_id", account.ID,
"current_version", currentVersion,
"latest_version", latestVersion)
return latestAccount, true
}
return latestAccount, false
}
......@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
// 这是为了处理:首次获取 token 时可能没有 project_id,之后自动检测到后会使用新 key
require.Equal(t, []string{"gemini:project-x", "gemini:account:10"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "gemini-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require.Equal(t, []string{"gemini:account:10"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
......@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
require.Equal(t, []string{"ag:ag-project", "ag:account:99"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ag-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require.Equal(t, []string{"ag:account:99"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
......@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 新行为:删除失败只记录日志,不返回错误
// 这是因为缓存失效失败不应影响主业务流程
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.Error(t, err)
require.Equal(t, expectedErr, err)
require.NoError(t, err)
})
}
}
......@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
}
// 新行为:Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
expectedKeys := []string{
"gemini:gemini-proj",
"gemini:account:1",
"ag:ag-proj",
"ag:account:2",
"openai:account:3",
"claude:account:4",
}
......@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
require.Equal(t, expectedKeys, cache.deletedKeys)
}
// ========== GetCredentialAsInt64 测试 ==========
func TestAccount_GetCredentialAsInt64(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
key string
expected int64
}{
{
name: "int64_value",
credentials: map[string]any{"_token_version": int64(1737654321000)},
key: "_token_version",
expected: 1737654321000,
},
{
name: "float64_value",
credentials: map[string]any{"_token_version": float64(1737654321000)},
key: "_token_version",
expected: 1737654321000,
},
{
name: "int_value",
credentials: map[string]any{"_token_version": 12345},
key: "_token_version",
expected: 12345,
},
{
name: "string_value",
credentials: map[string]any{"_token_version": "1737654321000"},
key: "_token_version",
expected: 1737654321000,
},
{
name: "string_with_spaces",
credentials: map[string]any{"_token_version": " 1737654321000 "},
key: "_token_version",
expected: 1737654321000,
},
{
name: "nil_credentials",
credentials: nil,
key: "_token_version",
expected: 0,
},
{
name: "missing_key",
credentials: map[string]any{"other_key": 123},
key: "_token_version",
expected: 0,
},
{
name: "nil_value",
credentials: map[string]any{"_token_version": nil},
key: "_token_version",
expected: 0,
},
{
name: "invalid_string",
credentials: map[string]any{"_token_version": "not_a_number"},
key: "_token_version",
expected: 0,
},
{
name: "empty_string",
credentials: map[string]any{"_token_version": ""},
key: "_token_version",
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Credentials: tt.credentials}
result := account.GetCredentialAsInt64(tt.key)
require.Equal(t, tt.expected, result)
})
}
}
func TestAccount_GetCredentialAsInt64_NilAccount(t *testing.T) {
var account *Account
result := account.GetCredentialAsInt64("_token_version")
require.Equal(t, int64(0), result)
}
// ========== CheckTokenVersion 测试 ==========
func TestCheckTokenVersion(t *testing.T) {
tests := []struct {
name string
account *Account
latestAccount *Account
repoErr error
expectedStale bool
}{
{
name: "nil_account",
account: nil,
latestAccount: nil,
expectedStale: false,
},
{
name: "no_version_in_account_but_db_has_version",
account: &Account{
ID: 1,
Credentials: map[string]any{},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: true, // 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
},
{
name: "both_no_version",
account: &Account{
ID: 1,
Credentials: map[string]any{},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{},
},
expectedStale: false, // 两边都没有版本号,说明从未被异步刷新过,允许缓存
},
{
name: "same_version",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: false,
},
{
name: "current_version_newer",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(200)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: false,
},
{
name: "current_version_older_stale",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(200)},
},
expectedStale: true, // 当前版本过时
},
{
name: "repo_error",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: nil,
repoErr: errors.New("db error"),
expectedStale: false, // 查询失败,默认允许缓存
},
{
name: "repo_returns_nil",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: nil,
repoErr: nil,
expectedStale: false, // 查询返回 nil,默认允许缓存
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
// 这里我们直接测试函数的核心逻辑来验证行为
if tt.name == "nil_account" {
_, isStale := CheckTokenVersion(context.Background(), nil, nil)
require.Equal(t, tt.expectedStale, isStale)
return
}
// 模拟 CheckTokenVersion 的核心逻辑
account := tt.account
currentVersion := account.GetCredentialAsInt64("_token_version")
// 模拟 repo 查询
latestAccount := tt.latestAccount
if tt.repoErr != nil || latestAccount == nil {
require.Equal(t, tt.expectedStale, false)
return
}
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
if currentVersion == 0 && latestVersion > 0 {
require.Equal(t, tt.expectedStale, true)
return
}
// 情况2: 两边都没有版本号
if currentVersion == 0 && latestVersion == 0 {
require.Equal(t, tt.expectedStale, false)
return
}
// 情况3: 比较版本号
isStale := latestVersion > currentVersion
require.Equal(t, tt.expectedStale, isStale)
})
}
}
func TestCheckTokenVersion_NilRepo(t *testing.T) {
account := &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
}
_, isStale := CheckTokenVersion(context.Background(), account, nil)
require.False(t, isStale) // nil repo,默认允许缓存
}
......@@ -166,11 +166,29 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功,更新账号credentials
// 如果有新凭证,先更新(即使有错误也要保存 token)
if newCredentials != nil {
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}
if err == nil {
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
if account.Platform == PlatformAntigravity &&
account.Status == StatusError &&
strings.Contains(account.ErrorMessage, "missing_project_id:") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
} else {
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
}
}
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
......@@ -219,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func isNonRetryableRefreshError(err error) bool {
if err == nil {
return false
......@@ -230,6 +249,7 @@ func isNonRetryableRefreshError(err error) bool {
"invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
......
package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrTotpNotEnabled = infraerrors.BadRequest("TOTP_NOT_ENABLED", "totp feature is not enabled")
ErrTotpAlreadyEnabled = infraerrors.BadRequest("TOTP_ALREADY_ENABLED", "totp is already enabled for this account")
ErrTotpNotSetup = infraerrors.BadRequest("TOTP_NOT_SETUP", "totp is not set up for this account")
ErrTotpInvalidCode = infraerrors.BadRequest("TOTP_INVALID_CODE", "invalid totp code")
ErrTotpSetupExpired = infraerrors.BadRequest("TOTP_SETUP_EXPIRED", "totp setup session expired")
ErrTotpTooManyAttempts = infraerrors.TooManyRequests("TOTP_TOO_MANY_ATTEMPTS", "too many verification attempts, please try again later")
ErrVerifyCodeRequired = infraerrors.BadRequest("VERIFY_CODE_REQUIRED", "email verification code is required")
ErrPasswordRequired = infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
)
// TotpCache defines cache operations for TOTP service
type TotpCache interface {
// Setup session methods
GetSetupSession(ctx context.Context, userID int64) (*TotpSetupSession, error)
SetSetupSession(ctx context.Context, userID int64, session *TotpSetupSession, ttl time.Duration) error
DeleteSetupSession(ctx context.Context, userID int64) error
// Login session methods (for 2FA login flow)
GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error)
SetLoginSession(ctx context.Context, tempToken string, session *TotpLoginSession, ttl time.Duration) error
DeleteLoginSession(ctx context.Context, tempToken string) error
// Rate limiting
IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error)
GetVerifyAttempts(ctx context.Context, userID int64) (int, error)
ClearVerifyAttempts(ctx context.Context, userID int64) error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type SecretEncryptor interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
// TotpSetupSession represents a TOTP setup session
type TotpSetupSession struct {
Secret string // Plain text TOTP secret (not encrypted yet)
SetupToken string // Random token to verify setup request
CreatedAt time.Time
}
// TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct {
UserID int64
Email string
TokenExpiry time.Time
}
// TotpStatus represents the TOTP status for a user
type TotpStatus struct {
Enabled bool `json:"enabled"`
EnabledAt *time.Time `json:"enabled_at,omitempty"`
FeatureEnabled bool `json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"` // seconds until setup expires
}
const (
totpSetupTTL = 5 * time.Minute
totpLoginTTL = 5 * time.Minute
totpAttemptsTTL = 15 * time.Minute
maxTotpAttempts = 5
totpIssuer = "Sub2API"
)
// TotpService handles TOTP operations
type TotpService struct {
userRepo UserRepository
encryptor SecretEncryptor
cache TotpCache
settingService *SettingService
emailService *EmailService
emailQueueService *EmailQueueService
}
// NewTotpService creates a new TOTP service
func NewTotpService(
userRepo UserRepository,
encryptor SecretEncryptor,
cache TotpCache,
settingService *SettingService,
emailService *EmailService,
emailQueueService *EmailQueueService,
) *TotpService {
return &TotpService{
userRepo: userRepo,
encryptor: encryptor,
cache: cache,
settingService: settingService,
emailService: emailService,
emailQueueService: emailQueueService,
}
}
// GetStatus returns the TOTP status for a user
func (s *TotpService) GetStatus(ctx context.Context, userID int64) (*TotpStatus, error) {
featureEnabled := s.settingService.IsTotpEnabled(ctx)
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
return &TotpStatus{
Enabled: user.TotpEnabled,
EnabledAt: user.TotpEnabledAt,
FeatureEnabled: featureEnabled,
}, nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) InitiateSetup(ctx context.Context, userID int64, emailCode, password string) (*TotpSetupResponse, error) {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return nil, ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user.TotpEnabled {
return nil, ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return nil, ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return nil, err
}
} else {
// Email verification disabled - verify password
if password == "" {
return nil, ErrPasswordRequired
}
if !user.CheckPassword(password) {
return nil, ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key, err := totp.Generate(totp.GenerateOpts{
Issuer: totpIssuer,
AccountName: user.Email,
})
if err != nil {
return nil, fmt.Errorf("generate totp key: %w", err)
}
// Generate a random setup token
setupToken, err := generateRandomToken(32)
if err != nil {
return nil, fmt.Errorf("generate setup token: %w", err)
}
// Store the setup session in cache
session := &TotpSetupSession{
Secret: key.Secret(),
SetupToken: setupToken,
CreatedAt: time.Now(),
}
if err := s.cache.SetSetupSession(ctx, userID, session, totpSetupTTL); err != nil {
return nil, fmt.Errorf("store setup session: %w", err)
}
return &TotpSetupResponse{
Secret: key.Secret(),
QRCodeURL: key.URL(),
SetupToken: setupToken,
Countdown: int(totpSetupTTL.Seconds()),
}, nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func (s *TotpService) CompleteSetup(ctx context.Context, userID int64, totpCode, setupToken string) error {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return ErrTotpNotEnabled
}
// Get the setup session
session, err := s.cache.GetSetupSession(ctx, userID)
if err != nil {
return ErrTotpSetupExpired
}
if session == nil {
return ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if subtle.ConstantTimeCompare([]byte(session.SetupToken), []byte(setupToken)) != 1 {
return ErrTotpSetupExpired
}
// Verify the TOTP code
if !totp.Validate(totpCode, session.Secret) {
return ErrTotpInvalidCode
}
setupSecretPrefix := "N/A"
if len(session.Secret) >= 4 {
setupSecretPrefix = session.Secret[:4]
}
slog.Debug("totp_complete_setup_before_encrypt",
"user_id", userID,
"secret_len", len(session.Secret),
"secret_prefix", setupSecretPrefix)
// Encrypt the secret
encryptedSecret, err := s.encryptor.Encrypt(session.Secret)
if err != nil {
return fmt.Errorf("encrypt totp secret: %w", err)
}
slog.Debug("totp_complete_setup_encrypted",
"user_id", userID,
"encrypted_len", len(encryptedSecret))
// Verify encryption by decrypting
decrypted, decErr := s.encryptor.Decrypt(encryptedSecret)
if decErr != nil {
slog.Debug("totp_complete_setup_verify_failed",
"user_id", userID,
"error", decErr)
} else {
decryptedPrefix := "N/A"
if len(decrypted) >= 4 {
decryptedPrefix = decrypted[:4]
}
slog.Debug("totp_complete_setup_verified",
"user_id", userID,
"original_len", len(session.Secret),
"decrypted_len", len(decrypted),
"match", session.Secret == decrypted,
"decrypted_prefix", decryptedPrefix)
}
// Update user with encrypted TOTP secret
if err := s.userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil {
return fmt.Errorf("update totp secret: %w", err)
}
// Enable TOTP for the user
if err := s.userRepo.EnableTotp(ctx, userID); err != nil {
return fmt.Errorf("enable totp: %w", err)
}
// Clean up the setup session
_ = s.cache.DeleteSetupSession(ctx, userID)
return nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) Disable(ctx context.Context, userID int64, emailCode, password string) error {
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if !user.TotpEnabled {
return ErrTotpNotSetup
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return err
}
} else {
// Email verification disabled - verify password
if password == "" {
return ErrPasswordRequired
}
if !user.CheckPassword(password) {
return ErrPasswordIncorrect
}
}
// Disable TOTP
if err := s.userRepo.DisableTotp(ctx, userID); err != nil {
return fmt.Errorf("disable totp: %w", err)
}
return nil
}
// VerifyCode verifies a TOTP code for a user
func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) error {
slog.Debug("totp_verify_code_called",
"user_id", userID,
"code_len", len(code))
// Check rate limiting
attempts, err := s.cache.GetVerifyAttempts(ctx, userID)
if err == nil && attempts >= maxTotpAttempts {
return ErrTotpTooManyAttempts
}
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
slog.Debug("totp_verify_get_user_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
if !user.TotpEnabled || user.TotpSecretEncrypted == nil {
slog.Debug("totp_verify_not_setup",
"user_id", userID,
"enabled", user.TotpEnabled,
"has_secret", user.TotpSecretEncrypted != nil)
return ErrTotpNotSetup
}
slog.Debug("totp_verify_encrypted_secret",
"user_id", userID,
"encrypted_len", len(*user.TotpSecretEncrypted))
// Decrypt the secret
secret, err := s.encryptor.Decrypt(*user.TotpSecretEncrypted)
if err != nil {
slog.Debug("totp_verify_decrypt_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
secretPrefix := "N/A"
if len(secret) >= 4 {
secretPrefix = secret[:4]
}
slog.Debug("totp_verify_decrypted",
"user_id", userID,
"secret_len", len(secret),
"secret_prefix", secretPrefix)
// Verify the code
valid := totp.Validate(code, secret)
slog.Debug("totp_verify_result",
"user_id", userID,
"valid", valid,
"secret_len", len(secret),
"secret_prefix", secretPrefix,
"server_time", time.Now().UTC().Format(time.RFC3339))
if !valid {
// Increment failed attempts
_, _ = s.cache.IncrementVerifyAttempts(ctx, userID)
return ErrTotpInvalidCode
}
// Clear attempt counter on success
_ = s.cache.ClearVerifyAttempts(ctx, userID)
return nil
}
// CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
// Generate a random temp token
tempToken, err := generateRandomToken(32)
if err != nil {
return "", fmt.Errorf("generate temp token: %w", err)
}
session := &TotpLoginSession{
UserID: userID,
Email: email,
TokenExpiry: time.Now().Add(totpLoginTTL),
}
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
return "", fmt.Errorf("store login session: %w", err)
}
return tempToken, nil
}
// GetLoginSession retrieves a login session
func (s *TotpService) GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) {
return s.cache.GetLoginSession(ctx, tempToken)
}
// DeleteLoginSession deletes a login session
func (s *TotpService) DeleteLoginSession(ctx context.Context, tempToken string) error {
return s.cache.DeleteLoginSession(ctx, tempToken)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func (s *TotpService) IsTotpEnabledForUser(ctx context.Context, userID int64) (bool, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return false, fmt.Errorf("get user: %w", err)
}
return user.TotpEnabled, nil
}
// MaskEmail masks an email address for display
func MaskEmail(email string) string {
if len(email) < 3 {
return "***"
}
atIdx := -1
for i, c := range email {
if c == '@' {
atIdx = i
break
}
}
if atIdx == -1 || atIdx < 1 {
return email[:1] + "***"
}
localPart := email[:atIdx]
domain := email[atIdx:]
if len(localPart) <= 2 {
return localPart[:1] + "***" + domain
}
return localPart[:1] + "***" + localPart[len(localPart)-1:] + domain
}
// generateRandomToken generates a random hex-encoded token
func generateRandomToken(byteLength int) (string, error) {
b := make([]byte, byteLength)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// VerificationMethod represents the method required for TOTP operations
type VerificationMethod struct {
Method string `json:"method"` // "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMethod {
if s.settingService.IsEmailVerifyEnabled(ctx) {
return &VerificationMethod{Method: "email"}
}
return &VerificationMethod{Method: "password"}
}
// SendVerifyCode sends an email verification code for TOTP operations
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
// Check if email verification is enabled
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
}
// Get user email
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
// Get site name for email
siteName := s.settingService.GetSiteName(ctx)
// Send verification code via queue
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
}
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
UsageCleanupStatusPending = "pending"
UsageCleanupStatusRunning = "running"
UsageCleanupStatusSucceeded = "succeeded"
UsageCleanupStatusFailed = "failed"
UsageCleanupStatusCanceled = "canceled"
)
// UsageCleanupFilters 定义清理任务过滤条件
// 时间范围为必填,其他字段可选
// JSON 序列化用于存储任务参数
//
// start_time/end_time 使用 RFC3339 时间格式
// 以 UTC 或用户时区解析后的时间为准
//
// 说明:
// - nil 表示未设置该过滤条件
// - 过滤条件均为精确匹配
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
// UsageCleanupTask 表示使用记录清理任务
// 状态包含 pending/running/succeeded/failed/canceled
type UsageCleanupTask struct {
ID int64
Status string
Filters UsageCleanupFilters
CreatedBy int64
DeletedRows int64
ErrorMsg *string
CanceledBy *int64
CanceledAt *time.Time
StartedAt *time.Time
FinishedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// UsageCleanupRepository 定义清理任务持久层接口
type UsageCleanupRepository interface {
CreateTask(ctx context.Context, task *UsageCleanupTask) error
ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
// ClaimNextPendingTask 抢占下一条可执行任务:
// - 优先 pending
// - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行
ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
GetTaskStatus(ctx context.Context, taskID int64) (string, error)
// UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
// CancelTask 将任务标记为 canceled(仅允许 pending/running)
CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
}
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
usageCleanupWorkerName = "usage_cleanup_worker"
)
// UsageCleanupService 负责创建与执行使用记录清理任务
type UsageCleanupService struct {
repo UsageCleanupRepository
timingWheel *TimingWheelService
dashboard *DashboardAggregationService
cfg *config.Config
running int32
startOnce sync.Once
stopOnce sync.Once
workerCtx context.Context
workerCancel context.CancelFunc
}
func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
workerCtx, workerCancel := context.WithCancel(context.Background())
return &UsageCleanupService{
repo: repo,
timingWheel: timingWheel,
dashboard: dashboard,
cfg: cfg,
workerCtx: workerCtx,
workerCancel: workerCancel,
}
}
func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
var parts []string
parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
if filters.UserID != nil {
parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
}
if filters.APIKeyID != nil {
parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
}
if filters.AccountID != nil {
parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
}
if filters.GroupID != nil {
parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
}
if filters.Model != nil {
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
}
if filters.Stream != nil {
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
}
if filters.BillingType != nil {
parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
}
return strings.Join(parts, " ")
}
func (s *UsageCleanupService) Start() {
if s == nil {
return
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
log.Printf("[UsageCleanup] not started (disabled)")
return
}
if s.repo == nil || s.timingWheel == nil {
log.Printf("[UsageCleanup] not started (missing deps)")
return
}
interval := s.workerInterval()
s.startOnce.Do(func() {
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
})
}
func (s *UsageCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.workerCancel != nil {
s.workerCancel()
}
if s.timingWheel != nil {
s.timingWheel.Cancel(usageCleanupWorkerName)
}
log.Printf("[UsageCleanup] stopped")
})
}
func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
if s == nil || s.repo == nil {
return nil, nil, fmt.Errorf("cleanup service not ready")
}
return s.repo.ListTasks(ctx, params)
}
func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
if s == nil || s.repo == nil {
return nil, fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if createdBy <= 0 {
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
}
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
sanitizeUsageCleanupFilters(&filters)
if err := s.validateFilters(filters); err != nil {
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, err
}
task := &UsageCleanupTask{
Status: UsageCleanupStatusPending,
Filters: filters,
CreatedBy: createdBy,
}
if err := s.repo.CreateTask(ctx, task); err != nil {
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, fmt.Errorf("create cleanup task: %w", err)
}
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
go s.runOnce()
return task, nil
}
func (s *UsageCleanupService) runOnce() {
svc := s
if svc == nil {
return
}
if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) {
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
return
}
defer atomic.StoreInt32(&svc.running, 0)
parent := context.Background()
if svc.workerCtx != nil {
parent = svc.workerCtx
}
ctx, cancel := context.WithTimeout(parent, svc.taskTimeout())
defer cancel()
task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds()))
if err != nil {
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
return
}
if task == nil {
log.Printf("[UsageCleanup] run_once done: no_task=true")
return
}
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
svc.executeTask(ctx, task)
}
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
if task == nil {
return
}
batchSize := s.batchSize()
deletedTotal := task.DeletedRows
start := time.Now()
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
var batchNum int
for {
if ctx != nil && ctx.Err() != nil {
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
return
}
canceled, err := s.isTaskCanceled(ctx, task.ID)
if err != nil {
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
if canceled {
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
return
}
batchNum++
deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
return
}
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
deletedTotal += deleted
if deleted > 0 {
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
}
cancel()
}
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
}
if deleted == 0 || deleted < int64(batchSize) {
break
}
}
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
}
if s.dashboard != nil {
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
}
}
}
func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
msg := strings.TrimSpace(err.Error())
if len(msg) > 500 {
msg = msg[:500]
}
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
}
}
func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
if s == nil || s.repo == nil {
return false, fmt.Errorf("cleanup service not ready")
}
checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
status, err := s.repo.GetTaskStatus(checkCtx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return false, err
}
if status == UsageCleanupStatusCanceled {
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
}
return status == UsageCleanupStatusCanceled, nil
}
func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
}
if filters.EndTime.Before(filters.StartTime) {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
}
maxDays := s.maxRangeDays()
if maxDays > 0 {
delta := filters.EndTime.Sub(filters.StartTime)
if delta > time.Duration(maxDays)*24*time.Hour {
return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
}
}
return nil
}
func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
if s == nil || s.repo == nil {
return fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if canceledBy <= 0 {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
}
status, err := s.repo.GetTaskStatus(ctx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
}
return err
}
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
if err != nil {
return err
}
if !ok {
// 状态可能并发改变
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
return nil
}
func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
if filters == nil {
return
}
if filters.UserID != nil && *filters.UserID <= 0 {
filters.UserID = nil
}
if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
filters.APIKeyID = nil
}
if filters.AccountID != nil && *filters.AccountID <= 0 {
filters.AccountID = nil
}
if filters.GroupID != nil && *filters.GroupID <= 0 {
filters.GroupID = nil
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model == "" {
filters.Model = nil
} else {
filters.Model = &model
}
}
if filters.BillingType != nil && *filters.BillingType < 0 {
filters.BillingType = nil
}
}
func (s *UsageCleanupService) maxRangeDays() int {
if s == nil || s.cfg == nil {
return 31
}
if s.cfg.UsageCleanup.MaxRangeDays > 0 {
return s.cfg.UsageCleanup.MaxRangeDays
}
return 31
}
func (s *UsageCleanupService) batchSize() int {
if s == nil || s.cfg == nil {
return 5000
}
if s.cfg.UsageCleanup.BatchSize > 0 {
return s.cfg.UsageCleanup.BatchSize
}
return 5000
}
func (s *UsageCleanupService) workerInterval() time.Duration {
if s == nil || s.cfg == nil {
return 10 * time.Second
}
if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
}
return 10 * time.Second
}
func (s *UsageCleanupService) taskTimeout() time.Duration {
if s == nil || s.cfg == nil {
return 30 * time.Minute
}
if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
}
return 30 * time.Minute
}
package service
import (
"context"
"database/sql"
"errors"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type cleanupDeleteResponse struct {
deleted int64
err error
}
type cleanupDeleteCall struct {
filters UsageCleanupFilters
limit int
}
type cleanupMarkCall struct {
taskID int64
deletedRows int64
errMsg string
}
type cleanupRepoStub struct {
mu sync.Mutex
created []*UsageCleanupTask
createErr error
listTasks []UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
claimQueue []*UsageCleanupTask
claimErr error
deleteQueue []cleanupDeleteResponse
deleteCalls []cleanupDeleteCall
markSucceeded []cleanupMarkCall
markFailed []cleanupMarkCall
statusByID map[int64]string
statusErr error
progressCalls []cleanupMarkCall
updateErr error
cancelCalls []int64
cancelErr error
cancelResult *bool
markFailedErr error
}
type dashboardRepoStub struct {
recomputeErr error
}
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return s.recomputeErr
}
func (s *dashboardRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return time.Time{}, nil
}
func (s *dashboardRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
return nil
}
func (s *dashboardRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
return nil
}
func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.createErr != nil {
return s.createErr
}
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
if task.UpdatedAt.IsZero() {
task.UpdatedAt = task.CreatedAt
}
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.claimErr != nil {
return nil, s.claimErr
}
if len(s.claimQueue) == 0 {
return nil, nil
}
task := s.claimQueue[0]
s.claimQueue = s.claimQueue[1:]
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[task.ID] = UsageCleanupStatusRunning
return task, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusErr != nil {
return "", s.statusErr
}
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
if s.updateErr != nil {
return s.updateErr
}
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.cancelCalls = append(s.cancelCalls, taskID)
if s.cancelErr != nil {
return false, s.cancelErr
}
if s.cancelResult != nil {
ok := *s.cancelResult
if ok {
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusCanceled
}
return ok, nil
}
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusSucceeded
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusFailed
if s.markFailedErr != nil {
return s.markFailedErr
}
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
if len(s.deleteQueue) == 0 {
return 0, nil
}
resp := s.deleteQueue[0]
s.deleteQueue = s.deleteQueue[1:]
return resp.deleted, resp.err
}
func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(-1)
apiKeyID := int64(10)
model := " gpt-4 "
billingType := int8(-2)
filters := UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
Model: &model,
BillingType: &billingType,
}
task, err := svc.CreateTask(context.Background(), filters, 9)
require.NoError(t, err)
require.Equal(t, UsageCleanupStatusPending, task.Status)
require.Nil(t, task.Filters.UserID)
require.NotNil(t, task.Filters.APIKeyID)
require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
require.NotNil(t, task.Filters.Model)
require.Equal(t, "gpt-4", *task.Filters.Model)
require.Nil(t, task.Filters.BillingType)
require.Equal(t, int64(9), task.CreatedBy)
}
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 0)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(48 * time.Hour)
filters := UsageCleanupFilters{StartTime: start, EndTime: end}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
_, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
repo := &cleanupRepoStub{createErr: errors.New("db down")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "create cleanup task")
}
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(2 * time.Hour)
repo := &cleanupRepoStub{
claimQueue: []*UsageCleanupTask{
{ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
},
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
{deleted: 2},
{deleted: 1},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
require.Equal(t, 2, repo.deleteCalls[0].limit)
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
require.Equal(t, 2, repo.deleteCalls[0].limit)
require.Equal(t, start, repo.deleteCalls[0].filters.StartTime)
require.Equal(t, end, repo.deleteCalls[0].filters.EndTime)
}
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
}
func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.running = 1
svc.runOnce()
}
func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
longMsg := strings.Repeat("x", 600)
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{err: errors.New(longMsg)},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 11,
Filters: UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markFailed, 1)
require.Equal(t, int64(11), repo.markFailed[0].taskID)
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
}
func TestUsageCleanupServiceExecuteTaskProgressError(t *testing.T) {
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
{deleted: 0},
},
updateErr: errors.New("update failed"),
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 8,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Len(t, repo.progressCalls, 1)
}
func TestUsageCleanupServiceExecuteTaskDeleteCanceled(t *testing.T) {
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{err: context.Canceled},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 12,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
}
func TestUsageCleanupServiceExecuteTaskContextCanceled(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 9,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
svc.executeTask(ctx, task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
require.Empty(t, repo.deleteCalls)
}
func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{err: errors.New("boom")},
},
markFailedErr: errors.New("update failed"),
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 13,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markFailed, 1)
require.Equal(t, int64(13), repo.markFailed[0].taskID)
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
task := &UsageCleanupTask{
ID: 14,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
task := &UsageCleanupTask{
ID: 15,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
}
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
repo := &cleanupRepoStub{
statusByID: map[int64]string{
3: UsageCleanupStatusCanceled,
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 3,
Filters: UsageCleanupFilters{
StartTime: time.Now().UTC(),
EndTime: time.Now().UTC().Add(time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.deleteCalls)
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
}
func TestUsageCleanupServiceCancelTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{
statusByID: map[int64]string{
5: UsageCleanupStatusPending,
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 5, 9)
require.NoError(t, err)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Equal(t, UsageCleanupStatusCanceled, repo.statusByID[5])
require.Len(t, repo.cancelCalls, 1)
}
func TestUsageCleanupServiceCancelTaskDisabled(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 1, 2)
require.Error(t, err)
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCancelTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 999, 1)
require.Error(t, err)
require.Equal(t, http.StatusNotFound, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_TASK_NOT_FOUND", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCancelTaskStatusError(t *testing.T) {
repo := &cleanupRepoStub{statusErr: errors.New("status broken")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 7, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "status broken")
}
func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{
statusByID: map[int64]string{
7: UsageCleanupStatusSucceeded,
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 7, 1)
require.Error(t, err)
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
shouldCancel := false
repo := &cleanupRepoStub{
statusByID: map[int64]string{
7: UsageCleanupStatusPending,
},
cancelResult: &shouldCancel,
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 7, 1)
require.Error(t, err)
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCancelTaskRepoError(t *testing.T) {
repo := &cleanupRepoStub{
statusByID: map[int64]string{
7: UsageCleanupStatusPending,
},
cancelErr: errors.New("cancel failed"),
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 7, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "cancel failed")
}
func TestUsageCleanupServiceCancelTaskInvalidCanceller(t *testing.T) {
repo := &cleanupRepoStub{
statusByID: map[int64]string{
7: UsageCleanupStatusRunning,
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
err := svc.CancelTask(context.Background(), 7, 0)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_INVALID_CANCELLER", infraerrors.Reason(err))
}
func TestUsageCleanupServiceListTasks(t *testing.T) {
repo := &cleanupRepoStub{
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
listResult: &pagination.PaginationResult{
Total: 2,
Page: 1,
PageSize: 20,
Pages: 1,
},
}
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 2)
require.Equal(t, int64(2), result.Total)
}
func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
var nilSvc *UsageCleanupService
_, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
_, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
}
func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
var nilSvc *UsageCleanupService
require.Equal(t, 31, nilSvc.maxRangeDays())
require.Equal(t, 5000, nilSvc.batchSize())
require.Equal(t, 10*time.Second, nilSvc.workerInterval())
require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
nilSvc.Start()
nilSvc.Stop()
repo := &cleanupRepoStub{}
cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
svcDisabled.Start()
svcDisabled.Stop()
timingWheel, err := NewTimingWheelService()
require.NoError(t, err)
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
require.Equal(t, 5*time.Second, svc.workerInterval())
svc.Start()
svc.Stop()
cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
require.Equal(t, 31, svcFallback.maxRangeDays())
require.Equal(t, 5000, svcFallback.batchSize())
require.Equal(t, 10*time.Second, svcFallback.workerInterval())
svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
svcMissingDeps.Start()
}
func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
model := " "
apiKeyID := int64(-5)
accountID := int64(-1)
groupID := int64(-2)
filters := UsageCleanupFilters{
UserID: &apiKeyID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
}
sanitizeUsageCleanupFilters(&filters)
require.Nil(t, filters.UserID)
require.Nil(t, filters.APIKeyID)
require.Nil(t, filters.AccountID)
require.Nil(t, filters.GroupID)
require.Nil(t, filters.Model)
}
func TestDescribeUsageCleanupFiltersAllFields(t *testing.T) {
start := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC)
end := start.Add(2 * time.Hour)
userID := int64(1)
apiKeyID := int64(2)
accountID := int64(3)
groupID := int64(4)
model := " gpt-4 "
stream := true
billingType := int8(2)
filters := UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
Stream: &stream,
BillingType: &billingType,
}
desc := describeUsageCleanupFilters(filters)
require.Equal(t, "start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2", desc)
}
func TestUsageCleanupServiceIsTaskCanceledNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
canceled, err := svc.isTaskCanceled(context.Background(), 9)
require.NoError(t, err)
require.False(t, canceled)
}
func TestUsageCleanupServiceIsTaskCanceledError(t *testing.T) {
repo := &cleanupRepoStub{statusErr: errors.New("status err")}
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
_, err := svc.isTaskCanceled(context.Background(), 9)
require.Error(t, err)
require.Contains(t, err.Error(), "status err")
}
......@@ -21,6 +21,11 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间
APIKeys []APIKey
Subscriptions []UserSubscription
}
......
......@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// TOTP 相关方法
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
EnableTotp(ctx context.Context, userID int64) error
DisableTotp(ctx context.Context, userID int64) error
}
// UpdateProfileRequest 更新用户资料请求
......
......@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
package service
import (
"context"
"database/sql"
"time"
......@@ -57,6 +58,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
return svc
}
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
svc.Start()
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
......@@ -64,6 +72,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
return svc
}
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
svc.Start()
return svc
}
// ProvideTimingWheelService creates and starts TimingWheelService
func ProvideTimingWheelService() (*TimingWheelService, error) {
svc, err := NewTimingWheelService()
......@@ -189,6 +204,8 @@ func ProvideOpsScheduledReportService(
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
// Start Pub/Sub subscriber for L1 cache invalidation across instances
apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
return apiKeyService
}
......@@ -246,10 +263,13 @@ var ProviderSet = wire.NewSet(
ProvideUpdateService,
ProvideTokenRefreshService,
ProvideAccountExpiryService,
ProvideSubscriptionExpiryService,
ProvideTimingWheelService,
ProvideDashboardAggregationService,
ProvideUsageCleanupService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,
NewUsageCache,
NewTotpService,
)
......@@ -46,7 +46,7 @@ func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
}
}
return trimmed, nil
return strings.TrimRight(trimmed, "/"), nil
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
......
......@@ -21,4 +21,31 @@ func TestValidateURLFormat(t *testing.T) {
if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
t.Fatalf("expected invalid port to fail")
}
// 验证末尾斜杠被移除
normalized, err := ValidateURLFormat("https://example.com/", false)
if err != nil {
t.Fatalf("expected trailing slash url to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected trailing slash to be removed, got %s", normalized)
}
// 验证多个末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com///", false)
if err != nil {
t.Fatalf("expected multiple trailing slashes to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected all trailing slashes to be removed, got %s", normalized)
}
// 验证带路径的 URL 末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com/api/v1/", false)
if err != nil {
t.Fatalf("expected trailing slash url with path to pass, got %v", err)
}
if normalized != "https://example.com/api/v1" {
t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
}
}
-- 兼容旧库:若尚未创建 user_allowed_groups,则确保 users.allowed_groups 存在,避免 007 迁移回填失败。
DO $$
BEGIN
IF to_regclass('public.user_allowed_groups') IS NULL THEN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'users'
) THEN
ALTER TABLE users
ADD COLUMN IF NOT EXISTS allowed_groups BIGINT[] DEFAULT NULL;
END IF;
END IF;
END $$;
-- 兼容缺失 users.allowed_groups 的老库,确保 007 回填可执行。
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'users'
) THEN
IF NOT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'users'
AND column_name = 'allowed_groups'
) THEN
IF NOT EXISTS (
SELECT 1
FROM schema_migrations
WHERE filename = '014_drop_legacy_allowed_groups.sql'
) THEN
ALTER TABLE users
ADD COLUMN IF NOT EXISTS allowed_groups BIGINT[] DEFAULT NULL;
END IF;
END IF;
END IF;
END $$;
-- 042_add_usage_cleanup_tasks.sql
-- 使用记录清理任务表
CREATE TABLE IF NOT EXISTS usage_cleanup_tasks (
id BIGSERIAL PRIMARY KEY,
status VARCHAR(20) NOT NULL,
filters JSONB NOT NULL,
created_by BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
deleted_rows BIGINT NOT NULL DEFAULT 0,
error_message TEXT,
started_at TIMESTAMPTZ,
finished_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_status_created_at
ON usage_cleanup_tasks(status, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_created_at
ON usage_cleanup_tasks(created_at DESC);
-- 043_add_usage_cleanup_cancel_audit.sql
-- usage_cleanup_tasks 取消任务审计字段
ALTER TABLE usage_cleanup_tasks
ADD COLUMN IF NOT EXISTS canceled_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
ADD COLUMN IF NOT EXISTS canceled_at TIMESTAMPTZ;
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_canceled_at
ON usage_cleanup_tasks(canceled_at DESC);
-- 为 users 表添加 TOTP 双因素认证字段
ALTER TABLE users
ADD COLUMN IF NOT EXISTS totp_secret_encrypted TEXT DEFAULT NULL,
ADD COLUMN IF NOT EXISTS totp_enabled BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN IF NOT EXISTS totp_enabled_at TIMESTAMPTZ DEFAULT NULL;
COMMENT ON COLUMN users.totp_secret_encrypted IS 'AES-256-GCM 加密的 TOTP 密钥';
COMMENT ON COLUMN users.totp_enabled IS '是否启用 TOTP 双因素认证';
COMMENT ON COLUMN users.totp_enabled_at IS 'TOTP 启用时间';
-- 创建索引以支持快速查询启用 2FA 的用户
CREATE INDEX IF NOT EXISTS idx_users_totp_enabled ON users(totp_enabled) WHERE deleted_at IS NULL AND totp_enabled = true;
......@@ -251,6 +251,27 @@ dashboard_aggregation:
# 日聚合保留天数
daily_days: 730
# =============================================================================
# Usage Cleanup Task Configuration
# 使用记录清理任务配置(重启生效)
# =============================================================================
usage_cleanup:
# Enable cleanup task worker
# 启用清理任务执行器
enabled: true
# Max date range (days) per task
# 单次任务最大时间跨度(天)
max_range_days: 31
# Batch delete size
# 单批删除数量
batch_size: 5000
# Worker interval (seconds)
# 执行器轮询间隔(秒)
worker_interval_seconds: 10
# Task execution timeout (seconds)
# 单次任务最大执行时长(秒)
task_timeout_seconds: 1800
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置
......
......@@ -61,6 +61,18 @@ ADMIN_PASSWORD=
JWT_SECRET=
JWT_EXPIRE_HOUR=24
# -----------------------------------------------------------------------------
# TOTP (2FA) Configuration
# TOTP(双因素认证)配置
# -----------------------------------------------------------------------------
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, a
# random key will be generated on each startup, causing all existing TOTP
# configurations to become invalid (users won't be able to login with 2FA).
# Generate a secure key: openssl rand -hex 32
# 重要:设置固定的 TOTP 加密密钥。如果留空,每次启动将生成随机密钥,
# 导致现有的 TOTP 配置失效(用户无法使用双因素认证登录)。
TOTP_ENCRYPTION_KEY=
# -----------------------------------------------------------------------------
# Configuration File (Optional)
# -----------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment