Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0170d19f
Commit
0170d19f
authored
Feb 02, 2026
by
song
Browse files
merge upstream main
parent
7ade9baa
Changes
319
Show whitespace changes
Inline
Side-by-side
backend/internal/service/ratelimit_service_openai_test.go
0 → 100644
View file @
0170d19f
package
service
import
(
"net/http"
"testing"
"time"
)
func
TestCalculateOpenAI429ResetTime_7dExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
// ~4.5 days
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
// ~4.8 hours
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 384607 seconds from now
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_5hExhausted
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Simulate headers when 5h limit is exhausted (100% used)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should be approximately 3600 seconds from now
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Neither limit at 100%, should use the longer reset time
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"80"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"100000"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"90"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"5000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should use the max (100000 seconds from 7d window)
expectedDuration
:=
100000
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestCalculateOpenAI429ResetTime_NoCodexHeaders
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// No codex headers at all
headers
:=
http
.
Header
{}
headers
.
Set
(
"content-type"
,
"application/json"
)
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil resetAt when no codex headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_ReversedWindowOrder
(
t
*
testing
.
T
)
{
svc
:=
&
RateLimitService
{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// This is 5h
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"3600"
)
// 1 hour
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"300"
)
// 5 hours - smaller!
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"50"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"500000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"10080"
)
// 7 days - larger!
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt"
)
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration
:=
3600
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
}
func
TestNormalizedCodexLimits
(
t
*
testing
.
T
)
{
// Test the Normalize() method directly
pUsed
:=
100.0
pReset
:=
384607
pWindow
:=
10080
sUsed
:=
3.0
sReset
:=
17369
sWindow
:=
300
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
PrimaryWindowMinutes
:
&
pWindow
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
SecondaryWindowMinutes
:
&
sWindow
,
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Primary has larger window (10080 > 300), so primary should be 7d
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
384607
{
t
.
Errorf
(
"expected Reset7dSeconds=384607, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
3.0
{
t
.
Errorf
(
"expected Used5hPercent=3, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
17369
{
t
.
Errorf
(
"expected Reset5hSeconds=17369, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlyPrimaryData
(
t
*
testing
.
T
)
{
// Test when only primary has data, no window_minutes
pUsed
:=
80.0
pReset
:=
50000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
// No window_minutes, no secondary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
80.0
{
t
.
Errorf
(
"expected Used7dPercent=80, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
50000
{
t
.
Errorf
(
"expected Reset7dSeconds=50000, got %v"
,
normalized
.
Reset7dSeconds
)
}
// Secondary (5h) should be nil
if
normalized
.
Used5hPercent
!=
nil
{
t
.
Errorf
(
"expected Used5hPercent=nil, got %v"
,
*
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
!=
nil
{
t
.
Errorf
(
"expected Reset5hSeconds=nil, got %v"
,
*
normalized
.
Reset5hSeconds
)
}
}
func
TestNormalizedCodexLimits_OnlySecondaryData
(
t
*
testing
.
T
)
{
// Test when only secondary has data, no window_minutes
sUsed
:=
60.0
sReset
:=
3000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes, no primary data
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
60.0
{
t
.
Errorf
(
"expected Used5hPercent=60, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
3000
{
t
.
Errorf
(
"expected Reset5hSeconds=3000, got %v"
,
normalized
.
Reset5hSeconds
)
}
// Primary (7d) should be nil
if
normalized
.
Used7dPercent
!=
nil
{
t
.
Errorf
(
"expected Used7dPercent=nil, got %v"
,
*
normalized
.
Used7dPercent
)
}
}
func
TestNormalizedCodexLimits_BothDataNoWindowMinutes
(
t
*
testing
.
T
)
{
// Test when both have data but no window_minutes
pUsed
:=
100.0
pReset
:=
400000
sUsed
:=
50.0
sReset
:=
10000
snapshot
:=
&
OpenAICodexUsageSnapshot
{
PrimaryUsedPercent
:
&
pUsed
,
PrimaryResetAfterSeconds
:
&
pReset
,
SecondaryUsedPercent
:
&
sUsed
,
SecondaryResetAfterSeconds
:
&
sReset
,
// No window_minutes
}
normalized
:=
snapshot
.
Normalize
()
if
normalized
==
nil
{
t
.
Fatal
(
"expected non-nil normalized"
)
}
// Legacy assumption: primary=7d, secondary=5h
if
normalized
.
Used7dPercent
==
nil
||
*
normalized
.
Used7dPercent
!=
100.0
{
t
.
Errorf
(
"expected Used7dPercent=100, got %v"
,
normalized
.
Used7dPercent
)
}
if
normalized
.
Reset7dSeconds
==
nil
||
*
normalized
.
Reset7dSeconds
!=
400000
{
t
.
Errorf
(
"expected Reset7dSeconds=400000, got %v"
,
normalized
.
Reset7dSeconds
)
}
if
normalized
.
Used5hPercent
==
nil
||
*
normalized
.
Used5hPercent
!=
50.0
{
t
.
Errorf
(
"expected Used5hPercent=50, got %v"
,
normalized
.
Used5hPercent
)
}
if
normalized
.
Reset5hSeconds
==
nil
||
*
normalized
.
Reset5hSeconds
!=
10000
{
t
.
Errorf
(
"expected Reset5hSeconds=10000, got %v"
,
normalized
.
Reset5hSeconds
)
}
}
func
TestHandle429_AnthropicPlatformUnaffected
(
t
*
testing
.
T
)
{
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc
:=
&
RateLimitService
{}
// Simulate Anthropic 429 headers
headers
:=
http
.
Header
{}
headers
.
Set
(
"anthropic-ratelimit-unified-reset"
,
"1737820800"
)
// A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there are no x-codex-* headers
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil for Anthropic headers, got %v"
,
resetAt
)
}
}
func
TestCalculateOpenAI429ResetTime_UserProvidedScenario
(
t
*
testing
.
T
)
{
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc
:=
&
RateLimitService
{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"384607"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
// 7 days = 10080 minutes
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"3"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"17369"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
// 5 hours = 300 minutes
before
:=
time
.
Now
()
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
after
:=
time
.
Now
()
if
resetAt
==
nil
{
t
.
Fatal
(
"expected non-nil resetAt for user scenario"
)
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration
:=
384607
*
time
.
Second
minExpected
:=
before
.
Add
(
expectedDuration
)
maxExpected
:=
after
.
Add
(
expectedDuration
)
if
resetAt
.
Before
(
minExpected
)
||
resetAt
.
After
(
maxExpected
)
{
t
.
Errorf
(
"resetAt %v not in expected range [%v, %v]"
,
resetAt
,
minExpected
,
maxExpected
)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration
:=
resetAt
.
Sub
(
before
)
actualDays
:=
duration
.
Hours
()
/
24.0
// 384607 / 86400 = ~4.45 days
if
actualDays
<
4.4
||
actualDays
>
4.5
{
t
.
Errorf
(
"expected ~4.45 days, got %.2f days"
,
actualDays
)
}
t
.
Logf
(
"User scenario: reset_at=%v, duration=%.2f days"
,
resetAt
,
actualDays
)
}
func
TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset
(
t
*
testing
.
T
)
{
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc
:=
&
RateLimitService
{}
headers
:=
http
.
Header
{}
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
// No reset_after_seconds!
resetAt
:=
svc
.
calculateOpenAI429ResetTime
(
headers
)
// Should return nil since there's no reset time available
if
resetAt
!=
nil
{
t
.
Errorf
(
"expected nil when no reset_after_seconds, got %v"
,
resetAt
)
}
}
backend/internal/service/session_limit_cache.go
View file @
0170d19f
...
...
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
idleTimeouts
map
[
int64
]
time
.
Duration
)
(
map
[
int64
]
int
,
error
)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
...
...
backend/internal/service/setting_service.go
View file @
0170d19f
...
...
@@ -60,6 +60,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys
:=
[]
string
{
SettingKeyRegistrationEnabled
,
SettingKeyEmailVerifyEnabled
,
SettingKeyPromoCodeEnabled
,
SettingKeyPasswordResetEnabled
,
SettingKeyTotpEnabled
,
SettingKeyTurnstileEnabled
,
SettingKeyTurnstileSiteKey
,
SettingKeySiteName
,
...
...
@@ -69,6 +72,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyContactInfo
,
SettingKeyDocURL
,
SettingKeyHomeContent
,
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeyLinuxDoConnectEnabled
,
}
...
...
@@ -84,9 +90,16 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled
=
s
.
cfg
!=
nil
&&
s
.
cfg
.
LinuxDo
.
Enabled
}
// Password reset requires email verification to be enabled
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
passwordResetEnabled
:=
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
passwordResetEnabled
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
...
...
@@ -96,6 +109,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
},
nil
}
...
...
@@ -123,6 +139,9 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
return
&
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key,omitempty"`
SiteName
string
`json:"site_name"`
...
...
@@ -132,11 +151,17 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo
string
`json:"contact_info,omitempty"`
DocURL
string
`json:"doc_url,omitempty"`
HomeContent
string
`json:"home_content,omitempty"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version,omitempty"`
}{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
PromoCodeEnabled
:
settings
.
PromoCodeEnabled
,
PasswordResetEnabled
:
settings
.
PasswordResetEnabled
,
TotpEnabled
:
settings
.
TotpEnabled
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
SiteName
:
settings
.
SiteName
,
...
...
@@ -146,6 +171,9 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
HomeContent
:
settings
.
HomeContent
,
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
s
.
version
,
},
nil
...
...
@@ -158,6 +186,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 注册设置
updates
[
SettingKeyRegistrationEnabled
]
=
strconv
.
FormatBool
(
settings
.
RegistrationEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyPromoCodeEnabled
]
=
strconv
.
FormatBool
(
settings
.
PromoCodeEnabled
)
updates
[
SettingKeyPasswordResetEnabled
]
=
strconv
.
FormatBool
(
settings
.
PasswordResetEnabled
)
updates
[
SettingKeyTotpEnabled
]
=
strconv
.
FormatBool
(
settings
.
TotpEnabled
)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeySMTPHost
]
=
settings
.
SMTPHost
...
...
@@ -193,6 +224,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocURL
]
=
settings
.
DocURL
updates
[
SettingKeyHomeContent
]
=
settings
.
HomeContent
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
...
...
@@ -243,6 +277,44 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
return
value
==
"true"
}
// IsPromoCodeEnabled 检查是否启用优惠码功能
func
(
s
*
SettingService
)
IsPromoCodeEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPromoCodeEnabled
)
if
err
!=
nil
{
return
true
// 默认启用
}
return
value
!=
"false"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func
(
s
*
SettingService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
// Password reset requires email verification to be enabled
if
!
s
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyPasswordResetEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func
(
s
*
SettingService
)
IsTotpEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyTotpEnabled
)
if
err
!=
nil
{
return
false
// 默认关闭
}
return
value
==
"true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func
(
s
*
SettingService
)
IsTotpEncryptionKeyConfigured
()
bool
{
return
s
.
cfg
.
Totp
.
EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
func
(
s
*
SettingService
)
GetSiteName
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
...
...
@@ -292,8 +364,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
defaults
:=
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeyPromoCodeEnabled
:
"true"
,
// 默认启用优惠码功能
SettingKeySiteName
:
"Sub2API"
,
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySMTPPort
:
"587"
,
...
...
@@ -320,9 +395,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
emailVerifyEnabled
:=
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
PromoCodeEnabled
:
settings
[
SettingKeyPromoCodeEnabled
]
!=
"false"
,
// 默认启用
PasswordResetEnabled
:
emailVerifyEnabled
&&
settings
[
SettingKeyPasswordResetEnabled
]
==
"true"
,
TotpEnabled
:
settings
[
SettingKeyTotpEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
...
...
@@ -339,6 +418,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
}
// 解析整数类型
...
...
backend/internal/service/settings_view.go
View file @
0170d19f
...
...
@@ -3,6 +3,9 @@ package service
type
SystemSettings
struct
{
RegistrationEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
SMTPHost
string
SMTPPort
int
...
...
@@ -32,6 +35,9 @@ type SystemSettings struct {
ContactInfo
string
DocURL
string
HomeContent
string
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
DefaultConcurrency
int
DefaultBalance
float64
...
...
@@ -57,6 +63,9 @@ type SystemSettings struct {
type
PublicSettings
struct
{
RegistrationEnabled
bool
EmailVerifyEnabled
bool
PromoCodeEnabled
bool
PasswordResetEnabled
bool
TotpEnabled
bool
// TOTP 双因素认证
TurnstileEnabled
bool
TurnstileSiteKey
string
SiteName
string
...
...
@@ -66,6 +75,11 @@ type PublicSettings struct {
ContactInfo
string
DocURL
string
HomeContent
string
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
LinuxDoOAuthEnabled
bool
Version
string
}
...
...
backend/internal/service/sticky_session_test.go
0 → 100644
View file @
0170d19f
//go:build unit
// Package service 提供 API 网关核心服务。
// 本文件包含 shouldClearStickySession 函数的单元测试,
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
//
// This file contains unit tests for the shouldClearStickySession function,
// verifying correct sticky session clearing behavior under various account states.
package
service
import
(
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
// 验证在以下情况下是否正确判断需要清理粘性会话:
// - nil 账号:不清理(返回 false)
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
func
TestShouldClearStickySession
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
future
:=
now
.
Add
(
1
*
time
.
Hour
)
past
:=
now
.
Add
(
-
1
*
time
.
Hour
)
tests
:=
[]
struct
{
name
string
account
*
Account
want
bool
}{
{
name
:
"nil account"
,
account
:
nil
,
want
:
false
},
{
name
:
"status error"
,
account
:
&
Account
{
Status
:
StatusError
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"status disabled"
,
account
:
&
Account
{
Status
:
StatusDisabled
,
Schedulable
:
true
},
want
:
true
},
{
name
:
"schedulable false"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
false
},
want
:
true
},
{
name
:
"temp unschedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
future
},
want
:
true
},
{
name
:
"temp unschedulable expired"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
,
TempUnschedulableUntil
:
&
past
},
want
:
false
},
{
name
:
"active schedulable"
,
account
:
&
Account
{
Status
:
StatusActive
,
Schedulable
:
true
},
want
:
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
shouldClearStickySession
(
tt
.
account
))
})
}
}
backend/internal/service/subscription_expiry_service.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type
SubscriptionExpiryService
struct
{
userSubRepo
UserSubscriptionRepository
interval
time
.
Duration
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
func
NewSubscriptionExpiryService
(
userSubRepo
UserSubscriptionRepository
,
interval
time
.
Duration
)
*
SubscriptionExpiryService
{
return
&
SubscriptionExpiryService
{
userSubRepo
:
userSubRepo
,
interval
:
interval
,
stopCh
:
make
(
chan
struct
{}),
}
}
func
(
s
*
SubscriptionExpiryService
)
Start
()
{
if
s
==
nil
||
s
.
userSubRepo
==
nil
||
s
.
interval
<=
0
{
return
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
s
.
interval
)
defer
ticker
.
Stop
()
s
.
runOnce
()
for
{
select
{
case
<-
ticker
.
C
:
s
.
runOnce
()
case
<-
s
.
stopCh
:
return
}
}
}()
}
func
(
s
*
SubscriptionExpiryService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
}
func
(
s
*
SubscriptionExpiryService
)
runOnce
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
updated
,
err
:=
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[SubscriptionExpiry] Update expired subscriptions failed: %v"
,
err
)
return
}
if
updated
>
0
{
log
.
Printf
(
"[SubscriptionExpiry] Updated %d expired subscriptions"
,
updated
)
}
}
backend/internal/service/subscription_service.go
View file @
0170d19f
...
...
@@ -27,6 +27,7 @@ var (
ErrWeeklyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"WEEKLY_LIMIT_EXCEEDED"
,
"weekly usage limit exceeded"
)
ErrMonthlyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"MONTHLY_LIMIT_EXCEEDED"
,
"monthly usage limit exceeded"
)
ErrSubscriptionNilInput
=
infraerrors
.
BadRequest
(
"SUBSCRIPTION_NIL_INPUT"
,
"subscription input cannot be nil"
)
ErrAdjustWouldExpire
=
infraerrors
.
BadRequest
(
"ADJUST_WOULD_EXPIRE"
,
"adjustment would result in expired subscription (remaining days must be > 0)"
)
)
// SubscriptionService 订阅服务
...
...
@@ -308,24 +309,48 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
return
nil
}
// ExtendSubscription
延长订阅
// ExtendSubscription
调整订阅时长(正数延长,负数缩短)
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetByID
(
ctx
,
subscriptionID
)
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
}
// 限制
延长天数
// 限制
调整天数范围
if
days
>
MaxValidityDays
{
days
=
MaxValidityDays
}
if
days
<
-
MaxValidityDays
{
days
=
-
MaxValidityDays
}
now
:=
time
.
Now
()
isExpired
:=
!
sub
.
ExpiresAt
.
After
(
now
)
// 如果订阅已过期,不允许负向调整
if
isExpired
&&
days
<
0
{
return
nil
,
infraerrors
.
BadRequest
(
"CANNOT_SHORTEN_EXPIRED"
,
"cannot shorten an expired subscription"
)
}
// 计算新的过期时间
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
var
newExpiresAt
time
.
Time
if
isExpired
{
// 已过期:从当前时间开始增加天数
newExpiresAt
=
now
.
AddDate
(
0
,
0
,
days
)
}
else
{
// 未过期:从原过期时间增加/减少天数
newExpiresAt
=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
}
if
newExpiresAt
.
After
(
MaxExpiresAt
)
{
newExpiresAt
=
MaxExpiresAt
}
// 检查新的过期时间必须大于当前时间
if
!
newExpiresAt
.
After
(
now
)
{
return
nil
,
ErrAdjustWouldExpire
}
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -371,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
nil
}
...
...
@@ -392,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return
nil
,
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
}
// List 获取所有订阅(分页,支持筛选)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
// List 获取所有订阅(分页,支持筛选
和排序
)
func
(
s
*
SubscriptionService
)
List
(
ctx
context
.
Context
,
page
,
pageSize
int
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
)
subs
,
pag
,
err
:=
s
.
userSubRepo
.
List
(
ctx
,
params
,
userID
,
groupID
,
status
,
sortBy
,
sortOrder
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
normalizeExpiredWindows
(
subs
)
normalizeSubscriptionStatus
(
subs
)
return
subs
,
pag
,
nil
}
...
...
@@ -429,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
}
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func
normalizeSubscriptionStatus
(
subs
[]
UserSubscription
)
{
now
:=
time
.
Now
()
for
i
:=
range
subs
{
sub
:=
&
subs
[
i
]
if
sub
.
Status
==
SubscriptionStatusActive
&&
!
sub
.
ExpiresAt
.
After
(
now
)
{
sub
.
Status
=
SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
...
...
@@ -647,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return
progresses
,
nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func
(
s
*
SubscriptionService
)
UpdateExpiredSubscriptions
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
s
.
userSubRepo
.
BatchUpdateExpiredStatus
(
ctx
)
}
// ValidateSubscription 验证订阅是否有效
func
(
s
*
SubscriptionService
)
ValidateSubscription
(
ctx
context
.
Context
,
sub
*
UserSubscription
)
error
{
if
sub
.
Status
==
SubscriptionStatusExpired
{
...
...
backend/internal/service/token_cache_invalidator.go
View file @
0170d19f
package
service
import
"context"
import
(
"context"
"log/slog"
"strconv"
)
type
TokenCacheInvalidator
interface
{
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
...
...
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
return
nil
}
var
cacheKey
string
var
keysToDelete
[]
string
accountIDKey
:=
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
switch
account
.
Platform
{
case
PlatformGemini
:
cacheKey
=
GeminiTokenCacheKey
(
account
)
// Gemini 可能有两种缓存键:project_id 或 account_id
// 首次获取 token 时可能没有 project_id,之后自动检测到 project_id 后会使用新 key
// 刷新时需要同时删除两种可能的 key,确保不会遗留旧缓存
keysToDelete
=
append
(
keysToDelete
,
GeminiTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"gemini:"
+
accountIDKey
)
case
PlatformAntigravity
:
cacheKey
=
AntigravityTokenCacheKey
(
account
)
// Antigravity 同样可能有两种缓存键
keysToDelete
=
append
(
keysToDelete
,
AntigravityTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"ag:"
+
accountIDKey
)
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
OpenAITokenCacheKey
(
account
)
)
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
keysToDelete
=
append
(
keysToDelete
,
ClaudeTokenCacheKey
(
account
)
)
default
:
return
nil
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
// 删除所有可能的缓存键(去重后)
seen
:=
make
(
map
[
string
]
bool
)
for
_
,
key
:=
range
keysToDelete
{
if
seen
[
key
]
{
continue
}
seen
[
key
]
=
true
if
err
:=
c
.
cache
.
DeleteAccessToken
(
ctx
,
key
);
err
!=
nil
{
slog
.
Warn
(
"token_cache_delete_failed"
,
"key"
,
key
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
nil
}
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
// 用于解决异步刷新任务与请求线程的竞态条件:
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
//
// 返回值:
// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil)
// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account
func
CheckTokenVersion
(
ctx
context
.
Context
,
account
*
Account
,
repo
AccountRepository
)
(
latestAccount
*
Account
,
isStale
bool
)
{
if
account
==
nil
||
repo
==
nil
{
return
nil
,
false
}
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
latestAccount
,
err
:=
repo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
||
latestAccount
==
nil
{
// 查询失败,默认允许缓存,不返回 latestAccount
return
nil
,
false
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
// 说明异步刷新任务已更新 token,当前 account 已过时
if
currentVersion
==
0
&&
latestVersion
>
0
{
slog
.
Debug
(
"token_version_stale_no_current_version"
,
"account_id"
,
account
.
ID
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
if
currentVersion
==
0
&&
latestVersion
==
0
{
return
latestAccount
,
false
}
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
if
latestVersion
>
currentVersion
{
slog
.
Debug
(
"token_version_stale"
,
"account_id"
,
account
.
ID
,
"current_version"
,
currentVersion
,
"latest_version"
,
latestVersion
)
return
latestAccount
,
true
}
return
latestAccount
,
false
}
backend/internal/service/token_cache_invalidator_test.go
View file @
0170d19f
...
...
@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
// 这是为了处理:首次获取 token 时可能没有 project_id,之后自动检测到后会使用新 key
require
.
Equal
(
t
,
[]
string
{
"gemini:project-x"
,
"gemini:account:10"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"gemini-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"gemini:account:10"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
...
...
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
,
"ag:account:99"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
99
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"ag-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require
.
Equal
(
t
,
[]
string
{
"ag:account:99"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
...
...
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 新行为:删除失败只记录日志,不返回错误
// 这是因为缓存失效失败不应影响主业务流程
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
expectedErr
,
err
)
require
.
NoError
(
t
,
err
)
})
}
}
...
...
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
}
// 新行为:Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
expectedKeys
:=
[]
string
{
"gemini:gemini-proj"
,
"gemini:account:1"
,
"ag:ag-proj"
,
"ag:account:2"
,
"openai:account:3"
,
"claude:account:4"
,
}
...
...
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
}
// ========== GetCredentialAsInt64 测试 ==========
func
TestAccount_GetCredentialAsInt64
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
key
string
expected
int64
}{
{
name
:
"int64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"float64_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
float64
(
1737654321000
)},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"int_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
12345
},
key
:
"_token_version"
,
expected
:
12345
,
},
{
name
:
"string_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"1737654321000"
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"string_with_spaces"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
" 1737654321000 "
},
key
:
"_token_version"
,
expected
:
1737654321000
,
},
{
name
:
"nil_credentials"
,
credentials
:
nil
,
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"missing_key"
,
credentials
:
map
[
string
]
any
{
"other_key"
:
123
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"nil_value"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
nil
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"invalid_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
"not_a_number"
},
key
:
"_token_version"
,
expected
:
0
,
},
{
name
:
"empty_string"
,
credentials
:
map
[
string
]
any
{
"_token_version"
:
""
},
key
:
"_token_version"
,
expected
:
0
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Credentials
:
tt
.
credentials
}
result
:=
account
.
GetCredentialAsInt64
(
tt
.
key
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestAccount_GetCredentialAsInt64_NilAccount
(
t
*
testing
.
T
)
{
var
account
*
Account
result
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
require
.
Equal
(
t
,
int64
(
0
),
result
)
}
// ========== CheckTokenVersion 测试 ==========
func
TestCheckTokenVersion
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
latestAccount
*
Account
repoErr
error
expectedStale
bool
}{
{
name
:
"nil_account"
,
account
:
nil
,
latestAccount
:
nil
,
expectedStale
:
false
,
},
{
name
:
"no_version_in_account_but_db_has_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
true
,
// 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
},
{
name
:
"both_no_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{},
},
expectedStale
:
false
,
// 两边都没有版本号,说明从未被异步刷新过,允许缓存
},
{
name
:
"same_version"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_newer"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
expectedStale
:
false
,
},
{
name
:
"current_version_older_stale"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
200
)},
},
expectedStale
:
true
,
// 当前版本过时
},
{
name
:
"repo_error"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
errors
.
New
(
"db error"
),
expectedStale
:
false
,
// 查询失败,默认允许缓存
},
{
name
:
"repo_returns_nil"
,
account
:
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
},
latestAccount
:
nil
,
repoErr
:
nil
,
expectedStale
:
false
,
// 查询返回 nil,默认允许缓存
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
// 这里我们直接测试函数的核心逻辑来验证行为
if
tt
.
name
==
"nil_account"
{
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
nil
,
nil
)
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
return
}
// 模拟 CheckTokenVersion 的核心逻辑
account
:=
tt
.
account
currentVersion
:=
account
.
GetCredentialAsInt64
(
"_token_version"
)
// 模拟 repo 查询
latestAccount
:=
tt
.
latestAccount
if
tt
.
repoErr
!=
nil
||
latestAccount
==
nil
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
latestVersion
:=
latestAccount
.
GetCredentialAsInt64
(
"_token_version"
)
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
if
currentVersion
==
0
&&
latestVersion
>
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
true
)
return
}
// 情况2: 两边都没有版本号
if
currentVersion
==
0
&&
latestVersion
==
0
{
require
.
Equal
(
t
,
tt
.
expectedStale
,
false
)
return
}
// 情况3: 比较版本号
isStale
:=
latestVersion
>
currentVersion
require
.
Equal
(
t
,
tt
.
expectedStale
,
isStale
)
})
}
}
func
TestCheckTokenVersion_NilRepo
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
1
,
Credentials
:
map
[
string
]
any
{
"_token_version"
:
int64
(
100
)},
}
_
,
isStale
:=
CheckTokenVersion
(
context
.
Background
(),
account
,
nil
)
require
.
False
(
t
,
isStale
)
// nil repo,默认允许缓存
}
backend/internal/service/token_refresh_service.go
View file @
0170d19f
...
...
@@ -18,6 +18,7 @@ type TokenRefreshService struct {
refreshers
[]
TokenRefresher
cfg
*
config
.
TokenRefreshConfig
cacheInvalidator
TokenCacheInvalidator
schedulerCache
SchedulerCache
// 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
...
...
@@ -31,12 +32,14 @@ func NewTokenRefreshService(
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cacheInvalidator
TokenCacheInvalidator
,
schedulerCache
SchedulerCache
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
accountRepo
:
accountRepo
,
cfg
:
&
cfg
.
TokenRefresh
,
cacheInvalidator
:
cacheInvalidator
,
schedulerCache
:
schedulerCache
,
stopCh
:
make
(
chan
struct
{}),
}
...
...
@@ -169,6 +172,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 如果有新凭证,先更新(即使有错误也要保存 token)
if
newCredentials
!=
nil
{
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
...
...
@@ -194,6 +201,15 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
log
.
Printf
(
"[TokenRefresh] Token cache invalidated for account %d"
,
account
.
ID
)
}
}
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
if
s
.
schedulerCache
!=
nil
{
if
err
:=
s
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[TokenRefresh] Failed to sync scheduler cache for account %d: %v"
,
account
.
ID
,
err
)
}
else
{
log
.
Printf
(
"[TokenRefresh] Scheduler cache synced for account %d"
,
account
.
ID
)
}
}
return
nil
}
...
...
@@ -233,7 +249,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func
isNonRetryableRefreshError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
...
...
backend/internal/service/token_refresh_service_test.go
View file @
0170d19f
...
...
@@ -70,7 +70,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
5
,
Platform
:
PlatformGemini
,
...
...
@@ -98,7 +98,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
6
,
Platform
:
PlatformGemini
,
...
...
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
7
,
Platform
:
PlatformGemini
,
...
...
@@ -151,7 +151,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
8
,
Platform
:
PlatformAntigravity
,
...
...
@@ -179,7 +179,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
9
,
Platform
:
PlatformGemini
,
...
...
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
// OpenAI OAuth 账户
...
...
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformGemini
,
...
...
@@ -264,7 +264,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
12
,
Platform
:
PlatformGemini
,
...
...
@@ -291,7 +291,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformAntigravity
,
...
...
@@ -318,7 +318,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
14
,
Platform
:
PlatformAntigravity
,
...
...
backend/internal/service/totp_service.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var
(
ErrTotpNotEnabled
=
infraerrors
.
BadRequest
(
"TOTP_NOT_ENABLED"
,
"totp feature is not enabled"
)
ErrTotpAlreadyEnabled
=
infraerrors
.
BadRequest
(
"TOTP_ALREADY_ENABLED"
,
"totp is already enabled for this account"
)
ErrTotpNotSetup
=
infraerrors
.
BadRequest
(
"TOTP_NOT_SETUP"
,
"totp is not set up for this account"
)
ErrTotpInvalidCode
=
infraerrors
.
BadRequest
(
"TOTP_INVALID_CODE"
,
"invalid totp code"
)
ErrTotpSetupExpired
=
infraerrors
.
BadRequest
(
"TOTP_SETUP_EXPIRED"
,
"totp setup session expired"
)
ErrTotpTooManyAttempts
=
infraerrors
.
TooManyRequests
(
"TOTP_TOO_MANY_ATTEMPTS"
,
"too many verification attempts, please try again later"
)
ErrVerifyCodeRequired
=
infraerrors
.
BadRequest
(
"VERIFY_CODE_REQUIRED"
,
"email verification code is required"
)
ErrPasswordRequired
=
infraerrors
.
BadRequest
(
"PASSWORD_REQUIRED"
,
"password is required"
)
)
// TotpCache defines cache operations for TOTP service
type
TotpCache
interface
{
// Setup session methods
GetSetupSession
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpSetupSession
,
error
)
SetSetupSession
(
ctx
context
.
Context
,
userID
int64
,
session
*
TotpSetupSession
,
ttl
time
.
Duration
)
error
DeleteSetupSession
(
ctx
context
.
Context
,
userID
int64
)
error
// Login session methods (for 2FA login flow)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
SetLoginSession
(
ctx
context
.
Context
,
tempToken
string
,
session
*
TotpLoginSession
,
ttl
time
.
Duration
)
error
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
// Rate limiting
IncrementVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
GetVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
ClearVerifyAttempts
(
ctx
context
.
Context
,
userID
int64
)
error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type
SecretEncryptor
interface
{
Encrypt
(
plaintext
string
)
(
string
,
error
)
Decrypt
(
ciphertext
string
)
(
string
,
error
)
}
// TotpSetupSession represents a TOTP setup session
type
TotpSetupSession
struct
{
Secret
string
// Plain text TOTP secret (not encrypted yet)
SetupToken
string
// Random token to verify setup request
CreatedAt
time
.
Time
}
// TotpLoginSession represents a pending 2FA login session
type
TotpLoginSession
struct
{
UserID
int64
Email
string
TokenExpiry
time
.
Time
}
// TotpStatus represents the TOTP status for a user
type
TotpStatus
struct
{
Enabled
bool
`json:"enabled"`
EnabledAt
*
time
.
Time
`json:"enabled_at,omitempty"`
FeatureEnabled
bool
`json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type
TotpSetupResponse
struct
{
Secret
string
`json:"secret"`
QRCodeURL
string
`json:"qr_code_url"`
SetupToken
string
`json:"setup_token"`
Countdown
int
`json:"countdown"`
// seconds until setup expires
}
const
(
totpSetupTTL
=
5
*
time
.
Minute
totpLoginTTL
=
5
*
time
.
Minute
totpAttemptsTTL
=
15
*
time
.
Minute
maxTotpAttempts
=
5
totpIssuer
=
"Sub2API"
)
// TotpService handles TOTP operations
type
TotpService
struct
{
userRepo
UserRepository
encryptor
SecretEncryptor
cache
TotpCache
settingService
*
SettingService
emailService
*
EmailService
emailQueueService
*
EmailQueueService
}
// NewTotpService creates a new TOTP service
func
NewTotpService
(
userRepo
UserRepository
,
encryptor
SecretEncryptor
,
cache
TotpCache
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailQueueService
*
EmailQueueService
,
)
*
TotpService
{
return
&
TotpService
{
userRepo
:
userRepo
,
encryptor
:
encryptor
,
cache
:
cache
,
settingService
:
settingService
,
emailService
:
emailService
,
emailQueueService
:
emailQueueService
,
}
}
// GetStatus returns the TOTP status for a user
func
(
s
*
TotpService
)
GetStatus
(
ctx
context
.
Context
,
userID
int64
)
(
*
TotpStatus
,
error
)
{
featureEnabled
:=
s
.
settingService
.
IsTotpEnabled
(
ctx
)
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
&
TotpStatus
{
Enabled
:
user
.
TotpEnabled
,
EnabledAt
:
user
.
TotpEnabledAt
,
FeatureEnabled
:
featureEnabled
,
},
nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
InitiateSetup
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
(
*
TotpSetupResponse
,
error
)
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
nil
,
ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
user
.
TotpEnabled
{
return
nil
,
ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
nil
,
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
nil
,
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
nil
,
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
nil
,
ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key
,
err
:=
totp
.
Generate
(
totp
.
GenerateOpts
{
Issuer
:
totpIssuer
,
AccountName
:
user
.
Email
,
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate totp key: %w"
,
err
)
}
// Generate a random setup token
setupToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate setup token: %w"
,
err
)
}
// Store the setup session in cache
session
:=
&
TotpSetupSession
{
Secret
:
key
.
Secret
(),
SetupToken
:
setupToken
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetSetupSession
(
ctx
,
userID
,
session
,
totpSetupTTL
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"store setup session: %w"
,
err
)
}
return
&
TotpSetupResponse
{
Secret
:
key
.
Secret
(),
QRCodeURL
:
key
.
URL
(),
SetupToken
:
setupToken
,
Countdown
:
int
(
totpSetupTTL
.
Seconds
()),
},
nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func
(
s
*
TotpService
)
CompleteSetup
(
ctx
context
.
Context
,
userID
int64
,
totpCode
,
setupToken
string
)
error
{
// Check if TOTP feature is enabled globally
if
!
s
.
settingService
.
IsTotpEnabled
(
ctx
)
{
return
ErrTotpNotEnabled
}
// Get the setup session
session
,
err
:=
s
.
cache
.
GetSetupSession
(
ctx
,
userID
)
if
err
!=
nil
{
return
ErrTotpSetupExpired
}
if
session
==
nil
{
return
ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if
subtle
.
ConstantTimeCompare
([]
byte
(
session
.
SetupToken
),
[]
byte
(
setupToken
))
!=
1
{
return
ErrTotpSetupExpired
}
// Verify the TOTP code
if
!
totp
.
Validate
(
totpCode
,
session
.
Secret
)
{
return
ErrTotpInvalidCode
}
setupSecretPrefix
:=
"N/A"
if
len
(
session
.
Secret
)
>=
4
{
setupSecretPrefix
=
session
.
Secret
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_before_encrypt"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
session
.
Secret
),
"secret_prefix"
,
setupSecretPrefix
)
// Encrypt the secret
encryptedSecret
,
err
:=
s
.
encryptor
.
Encrypt
(
session
.
Secret
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"encrypt totp secret: %w"
,
err
)
}
slog
.
Debug
(
"totp_complete_setup_encrypted"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
encryptedSecret
))
// Verify encryption by decrypting
decrypted
,
decErr
:=
s
.
encryptor
.
Decrypt
(
encryptedSecret
)
if
decErr
!=
nil
{
slog
.
Debug
(
"totp_complete_setup_verify_failed"
,
"user_id"
,
userID
,
"error"
,
decErr
)
}
else
{
decryptedPrefix
:=
"N/A"
if
len
(
decrypted
)
>=
4
{
decryptedPrefix
=
decrypted
[
:
4
]
}
slog
.
Debug
(
"totp_complete_setup_verified"
,
"user_id"
,
userID
,
"original_len"
,
len
(
session
.
Secret
),
"decrypted_len"
,
len
(
decrypted
),
"match"
,
session
.
Secret
==
decrypted
,
"decrypted_prefix"
,
decryptedPrefix
)
}
// Update user with encrypted TOTP secret
if
err
:=
s
.
userRepo
.
UpdateTotpSecret
(
ctx
,
userID
,
&
encryptedSecret
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update totp secret: %w"
,
err
)
}
// Enable TOTP for the user
if
err
:=
s
.
userRepo
.
EnableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"enable totp: %w"
,
err
)
}
// Clean up the setup session
_
=
s
.
cache
.
DeleteSetupSession
(
ctx
,
userID
)
return
nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func
(
s
*
TotpService
)
Disable
(
ctx
context
.
Context
,
userID
int64
,
emailCode
,
password
string
)
error
{
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
if
!
user
.
TotpEnabled
{
return
ErrTotpNotSetup
}
// Verify identity based on email verification setting
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// Email verification enabled - verify email code
if
emailCode
==
""
{
return
ErrVerifyCodeRequired
}
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
user
.
Email
,
emailCode
);
err
!=
nil
{
return
err
}
}
else
{
// Email verification disabled - verify password
if
password
==
""
{
return
ErrPasswordRequired
}
if
!
user
.
CheckPassword
(
password
)
{
return
ErrPasswordIncorrect
}
}
// Disable TOTP
if
err
:=
s
.
userRepo
.
DisableTotp
(
ctx
,
userID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"disable totp: %w"
,
err
)
}
return
nil
}
// VerifyCode verifies a TOTP code for a user
func
(
s
*
TotpService
)
VerifyCode
(
ctx
context
.
Context
,
userID
int64
,
code
string
)
error
{
slog
.
Debug
(
"totp_verify_code_called"
,
"user_id"
,
userID
,
"code_len"
,
len
(
code
))
// Check rate limiting
attempts
,
err
:=
s
.
cache
.
GetVerifyAttempts
(
ctx
,
userID
)
if
err
==
nil
&&
attempts
>=
maxTotpAttempts
{
return
ErrTotpTooManyAttempts
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_get_user_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
if
!
user
.
TotpEnabled
||
user
.
TotpSecretEncrypted
==
nil
{
slog
.
Debug
(
"totp_verify_not_setup"
,
"user_id"
,
userID
,
"enabled"
,
user
.
TotpEnabled
,
"has_secret"
,
user
.
TotpSecretEncrypted
!=
nil
)
return
ErrTotpNotSetup
}
slog
.
Debug
(
"totp_verify_encrypted_secret"
,
"user_id"
,
userID
,
"encrypted_len"
,
len
(
*
user
.
TotpSecretEncrypted
))
// Decrypt the secret
secret
,
err
:=
s
.
encryptor
.
Decrypt
(
*
user
.
TotpSecretEncrypted
)
if
err
!=
nil
{
slog
.
Debug
(
"totp_verify_decrypt_failed"
,
"user_id"
,
userID
,
"error"
,
err
)
return
infraerrors
.
InternalServer
(
"TOTP_VERIFY_ERROR"
,
"failed to verify totp code"
)
}
secretPrefix
:=
"N/A"
if
len
(
secret
)
>=
4
{
secretPrefix
=
secret
[
:
4
]
}
slog
.
Debug
(
"totp_verify_decrypted"
,
"user_id"
,
userID
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
)
// Verify the code
valid
:=
totp
.
Validate
(
code
,
secret
)
slog
.
Debug
(
"totp_verify_result"
,
"user_id"
,
userID
,
"valid"
,
valid
,
"secret_len"
,
len
(
secret
),
"secret_prefix"
,
secretPrefix
,
"server_time"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
))
if
!
valid
{
// Increment failed attempts
_
,
_
=
s
.
cache
.
IncrementVerifyAttempts
(
ctx
,
userID
)
return
ErrTotpInvalidCode
}
// Clear attempt counter on success
_
=
s
.
cache
.
ClearVerifyAttempts
(
ctx
,
userID
)
return
nil
}
// CreateLoginSession creates a temporary login session for 2FA
func
(
s
*
TotpService
)
CreateLoginSession
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
(
string
,
error
)
{
// Generate a random temp token
tempToken
,
err
:=
generateRandomToken
(
32
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate temp token: %w"
,
err
)
}
session
:=
&
TotpLoginSession
{
UserID
:
userID
,
Email
:
email
,
TokenExpiry
:
time
.
Now
()
.
Add
(
totpLoginTTL
),
}
if
err
:=
s
.
cache
.
SetLoginSession
(
ctx
,
tempToken
,
session
,
totpLoginTTL
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"store login session: %w"
,
err
)
}
return
tempToken
,
nil
}
// GetLoginSession retrieves a login session
func
(
s
*
TotpService
)
GetLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
(
*
TotpLoginSession
,
error
)
{
return
s
.
cache
.
GetLoginSession
(
ctx
,
tempToken
)
}
// DeleteLoginSession deletes a login session
func
(
s
*
TotpService
)
DeleteLoginSession
(
ctx
context
.
Context
,
tempToken
string
)
error
{
return
s
.
cache
.
DeleteLoginSession
(
ctx
,
tempToken
)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func
(
s
*
TotpService
)
IsTotpEnabledForUser
(
ctx
context
.
Context
,
userID
int64
)
(
bool
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
false
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
return
user
.
TotpEnabled
,
nil
}
// MaskEmail masks an email address for display
func
MaskEmail
(
email
string
)
string
{
if
len
(
email
)
<
3
{
return
"***"
}
atIdx
:=
-
1
for
i
,
c
:=
range
email
{
if
c
==
'@'
{
atIdx
=
i
break
}
}
if
atIdx
==
-
1
||
atIdx
<
1
{
return
email
[
:
1
]
+
"***"
}
localPart
:=
email
[
:
atIdx
]
domain
:=
email
[
atIdx
:
]
if
len
(
localPart
)
<=
2
{
return
localPart
[
:
1
]
+
"***"
+
domain
}
return
localPart
[
:
1
]
+
"***"
+
localPart
[
len
(
localPart
)
-
1
:
]
+
domain
}
// generateRandomToken generates a random hex-encoded token
func
generateRandomToken
(
byteLength
int
)
(
string
,
error
)
{
b
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
// VerificationMethod represents the method required for TOTP operations
type
VerificationMethod
struct
{
Method
string
`json:"method"`
// "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func
(
s
*
TotpService
)
GetVerificationMethod
(
ctx
context
.
Context
)
*
VerificationMethod
{
if
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
&
VerificationMethod
{
Method
:
"email"
}
}
return
&
VerificationMethod
{
Method
:
"password"
}
}
// SendVerifyCode sends an email verification code for TOTP operations
func
(
s
*
TotpService
)
SendVerifyCode
(
ctx
context
.
Context
,
userID
int64
)
error
{
// Check if email verification is enabled
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_NOT_ENABLED"
,
"email verification is not enabled"
)
}
// Get user email
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// Get site name for email
siteName
:=
s
.
settingService
.
GetSiteName
(
ctx
)
// Send verification code via queue
return
s
.
emailQueueService
.
EnqueueVerifyCode
(
user
.
Email
,
siteName
)
}
backend/internal/service/usage_cleanup.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const
(
UsageCleanupStatusPending
=
"pending"
UsageCleanupStatusRunning
=
"running"
UsageCleanupStatusSucceeded
=
"succeeded"
UsageCleanupStatusFailed
=
"failed"
UsageCleanupStatusCanceled
=
"canceled"
)
// UsageCleanupFilters 定义清理任务过滤条件
// 时间范围为必填,其他字段可选
// JSON 序列化用于存储任务参数
//
// start_time/end_time 使用 RFC3339 时间格式
// 以 UTC 或用户时区解析后的时间为准
//
// 说明:
// - nil 表示未设置该过滤条件
// - 过滤条件均为精确匹配
type
UsageCleanupFilters
struct
{
StartTime
time
.
Time
`json:"start_time"`
EndTime
time
.
Time
`json:"end_time"`
UserID
*
int64
`json:"user_id,omitempty"`
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
AccountID
*
int64
`json:"account_id,omitempty"`
GroupID
*
int64
`json:"group_id,omitempty"`
Model
*
string
`json:"model,omitempty"`
Stream
*
bool
`json:"stream,omitempty"`
BillingType
*
int8
`json:"billing_type,omitempty"`
}
// UsageCleanupTask 表示使用记录清理任务
// 状态包含 pending/running/succeeded/failed/canceled
type
UsageCleanupTask
struct
{
ID
int64
Status
string
Filters
UsageCleanupFilters
CreatedBy
int64
DeletedRows
int64
ErrorMsg
*
string
CanceledBy
*
int64
CanceledAt
*
time
.
Time
StartedAt
*
time
.
Time
FinishedAt
*
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
}
// UsageCleanupRepository 定义清理任务持久层接口
type
UsageCleanupRepository
interface
{
CreateTask
(
ctx
context
.
Context
,
task
*
UsageCleanupTask
)
error
ListTasks
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
UsageCleanupTask
,
*
pagination
.
PaginationResult
,
error
)
// ClaimNextPendingTask 抢占下一条可执行任务:
// - 优先 pending
// - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行
ClaimNextPendingTask
(
ctx
context
.
Context
,
staleRunningAfterSeconds
int64
)
(
*
UsageCleanupTask
,
error
)
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
GetTaskStatus
(
ctx
context
.
Context
,
taskID
int64
)
(
string
,
error
)
// UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示
UpdateTaskProgress
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
// CancelTask 将任务标记为 canceled(仅允许 pending/running)
CancelTask
(
ctx
context
.
Context
,
taskID
int64
,
canceledBy
int64
)
(
bool
,
error
)
MarkTaskSucceeded
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
MarkTaskFailed
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
,
errorMsg
string
)
error
DeleteUsageLogsBatch
(
ctx
context
.
Context
,
filters
UsageCleanupFilters
,
limit
int
)
(
int64
,
error
)
}
backend/internal/service/usage_cleanup_service.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"database/sql"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const
(
usageCleanupWorkerName
=
"usage_cleanup_worker"
)
// UsageCleanupService 负责创建与执行使用记录清理任务
type
UsageCleanupService
struct
{
repo
UsageCleanupRepository
timingWheel
*
TimingWheelService
dashboard
*
DashboardAggregationService
cfg
*
config
.
Config
running
int32
startOnce
sync
.
Once
stopOnce
sync
.
Once
workerCtx
context
.
Context
workerCancel
context
.
CancelFunc
}
func
NewUsageCleanupService
(
repo
UsageCleanupRepository
,
timingWheel
*
TimingWheelService
,
dashboard
*
DashboardAggregationService
,
cfg
*
config
.
Config
)
*
UsageCleanupService
{
workerCtx
,
workerCancel
:=
context
.
WithCancel
(
context
.
Background
())
return
&
UsageCleanupService
{
repo
:
repo
,
timingWheel
:
timingWheel
,
dashboard
:
dashboard
,
cfg
:
cfg
,
workerCtx
:
workerCtx
,
workerCancel
:
workerCancel
,
}
}
func
describeUsageCleanupFilters
(
filters
UsageCleanupFilters
)
string
{
var
parts
[]
string
parts
=
append
(
parts
,
"start="
+
filters
.
StartTime
.
UTC
()
.
Format
(
time
.
RFC3339
))
parts
=
append
(
parts
,
"end="
+
filters
.
EndTime
.
UTC
()
.
Format
(
time
.
RFC3339
))
if
filters
.
UserID
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"user_id=%d"
,
*
filters
.
UserID
))
}
if
filters
.
APIKeyID
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"api_key_id=%d"
,
*
filters
.
APIKeyID
))
}
if
filters
.
AccountID
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"account_id=%d"
,
*
filters
.
AccountID
))
}
if
filters
.
GroupID
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"group_id=%d"
,
*
filters
.
GroupID
))
}
if
filters
.
Model
!=
nil
{
parts
=
append
(
parts
,
"model="
+
strings
.
TrimSpace
(
*
filters
.
Model
))
}
if
filters
.
Stream
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"stream=%t"
,
*
filters
.
Stream
))
}
if
filters
.
BillingType
!=
nil
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
"billing_type=%d"
,
*
filters
.
BillingType
))
}
return
strings
.
Join
(
parts
,
" "
)
}
func
(
s
*
UsageCleanupService
)
Start
()
{
if
s
==
nil
{
return
}
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
UsageCleanup
.
Enabled
{
log
.
Printf
(
"[UsageCleanup] not started (disabled)"
)
return
}
if
s
.
repo
==
nil
||
s
.
timingWheel
==
nil
{
log
.
Printf
(
"[UsageCleanup] not started (missing deps)"
)
return
}
interval
:=
s
.
workerInterval
()
s
.
startOnce
.
Do
(
func
()
{
s
.
timingWheel
.
ScheduleRecurring
(
usageCleanupWorkerName
,
interval
,
s
.
runOnce
)
log
.
Printf
(
"[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)"
,
interval
,
s
.
maxRangeDays
(),
s
.
batchSize
(),
s
.
taskTimeout
())
})
}
func
(
s
*
UsageCleanupService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
workerCancel
!=
nil
{
s
.
workerCancel
()
}
if
s
.
timingWheel
!=
nil
{
s
.
timingWheel
.
Cancel
(
usageCleanupWorkerName
)
}
log
.
Printf
(
"[UsageCleanup] stopped"
)
})
}
func
(
s
*
UsageCleanupService
)
ListTasks
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
UsageCleanupTask
,
*
pagination
.
PaginationResult
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"cleanup service not ready"
)
}
return
s
.
repo
.
ListTasks
(
ctx
,
params
)
}
func
(
s
*
UsageCleanupService
)
CreateTask
(
ctx
context
.
Context
,
filters
UsageCleanupFilters
,
createdBy
int64
)
(
*
UsageCleanupTask
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
nil
,
fmt
.
Errorf
(
"cleanup service not ready"
)
}
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
UsageCleanup
.
Enabled
{
return
nil
,
infraerrors
.
New
(
http
.
StatusServiceUnavailable
,
"USAGE_CLEANUP_DISABLED"
,
"usage cleanup is disabled"
)
}
if
createdBy
<=
0
{
return
nil
,
infraerrors
.
BadRequest
(
"USAGE_CLEANUP_INVALID_CREATOR"
,
"invalid creator"
)
}
log
.
Printf
(
"[UsageCleanup] create_task requested: operator=%d %s"
,
createdBy
,
describeUsageCleanupFilters
(
filters
))
sanitizeUsageCleanupFilters
(
&
filters
)
if
err
:=
s
.
validateFilters
(
filters
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] create_task rejected: operator=%d err=%v %s"
,
createdBy
,
err
,
describeUsageCleanupFilters
(
filters
))
return
nil
,
err
}
task
:=
&
UsageCleanupTask
{
Status
:
UsageCleanupStatusPending
,
Filters
:
filters
,
CreatedBy
:
createdBy
,
}
if
err
:=
s
.
repo
.
CreateTask
(
ctx
,
task
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] create_task persist failed: operator=%d err=%v %s"
,
createdBy
,
err
,
describeUsageCleanupFilters
(
filters
))
return
nil
,
fmt
.
Errorf
(
"create cleanup task: %w"
,
err
)
}
log
.
Printf
(
"[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s"
,
task
.
ID
,
createdBy
,
task
.
Status
,
task
.
DeletedRows
,
describeUsageCleanupFilters
(
filters
))
go
s
.
runOnce
()
return
task
,
nil
}
func
(
s
*
UsageCleanupService
)
runOnce
()
{
svc
:=
s
if
svc
==
nil
{
return
}
if
!
atomic
.
CompareAndSwapInt32
(
&
svc
.
running
,
0
,
1
)
{
log
.
Printf
(
"[UsageCleanup] run_once skipped: already_running=true"
)
return
}
defer
atomic
.
StoreInt32
(
&
svc
.
running
,
0
)
parent
:=
context
.
Background
()
if
svc
.
workerCtx
!=
nil
{
parent
=
svc
.
workerCtx
}
ctx
,
cancel
:=
context
.
WithTimeout
(
parent
,
svc
.
taskTimeout
())
defer
cancel
()
task
,
err
:=
svc
.
repo
.
ClaimNextPendingTask
(
ctx
,
int64
(
svc
.
taskTimeout
()
.
Seconds
()))
if
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] claim pending task failed: %v"
,
err
)
return
}
if
task
==
nil
{
log
.
Printf
(
"[UsageCleanup] run_once done: no_task=true"
)
return
}
log
.
Printf
(
"[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s"
,
task
.
ID
,
task
.
Status
,
task
.
CreatedBy
,
task
.
DeletedRows
,
describeUsageCleanupFilters
(
task
.
Filters
))
svc
.
executeTask
(
ctx
,
task
)
}
func
(
s
*
UsageCleanupService
)
executeTask
(
ctx
context
.
Context
,
task
*
UsageCleanupTask
)
{
if
task
==
nil
{
return
}
batchSize
:=
s
.
batchSize
()
deletedTotal
:=
task
.
DeletedRows
start
:=
time
.
Now
()
log
.
Printf
(
"[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s"
,
task
.
ID
,
batchSize
,
deletedTotal
,
describeUsageCleanupFilters
(
task
.
Filters
))
var
batchNum
int
for
{
if
ctx
!=
nil
&&
ctx
.
Err
()
!=
nil
{
log
.
Printf
(
"[UsageCleanup] task interrupted: task=%d err=%v"
,
task
.
ID
,
ctx
.
Err
())
return
}
canceled
,
err
:=
s
.
isTaskCanceled
(
ctx
,
task
.
ID
)
if
err
!=
nil
{
s
.
markTaskFailed
(
task
.
ID
,
deletedTotal
,
err
)
return
}
if
canceled
{
log
.
Printf
(
"[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s"
,
task
.
ID
,
deletedTotal
,
time
.
Since
(
start
))
return
}
batchNum
++
deleted
,
err
:=
s
.
repo
.
DeleteUsageLogsBatch
(
ctx
,
task
.
Filters
,
batchSize
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
context
.
Canceled
)
||
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
log
.
Printf
(
"[UsageCleanup] task interrupted: task=%d err=%v"
,
task
.
ID
,
err
)
return
}
s
.
markTaskFailed
(
task
.
ID
,
deletedTotal
,
err
)
return
}
deletedTotal
+=
deleted
if
deleted
>
0
{
updateCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
if
err
:=
s
.
repo
.
UpdateTaskProgress
(
updateCtx
,
task
.
ID
,
deletedTotal
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v"
,
task
.
ID
,
deletedTotal
,
err
)
}
cancel
()
}
if
batchNum
<=
3
||
batchNum
%
20
==
0
||
deleted
<
int64
(
batchSize
)
{
log
.
Printf
(
"[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d"
,
task
.
ID
,
batchNum
,
deleted
,
deletedTotal
)
}
if
deleted
==
0
||
deleted
<
int64
(
batchSize
)
{
break
}
}
updateCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
repo
.
MarkTaskSucceeded
(
updateCtx
,
task
.
ID
,
deletedTotal
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] update task succeeded failed: task=%d err=%v"
,
task
.
ID
,
err
)
}
else
{
log
.
Printf
(
"[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s"
,
task
.
ID
,
deletedTotal
,
time
.
Since
(
start
))
}
if
s
.
dashboard
!=
nil
{
if
err
:=
s
.
dashboard
.
TriggerRecomputeRange
(
task
.
Filters
.
StartTime
,
task
.
Filters
.
EndTime
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v"
,
task
.
ID
,
err
)
}
else
{
log
.
Printf
(
"[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s"
,
task
.
ID
,
task
.
Filters
.
StartTime
.
UTC
()
.
Format
(
time
.
RFC3339
),
task
.
Filters
.
EndTime
.
UTC
()
.
Format
(
time
.
RFC3339
))
}
}
}
func
(
s
*
UsageCleanupService
)
markTaskFailed
(
taskID
int64
,
deletedRows
int64
,
err
error
)
{
msg
:=
strings
.
TrimSpace
(
err
.
Error
())
if
len
(
msg
)
>
500
{
msg
=
msg
[
:
500
]
}
log
.
Printf
(
"[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s"
,
taskID
,
deletedRows
,
msg
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
updateErr
:=
s
.
repo
.
MarkTaskFailed
(
ctx
,
taskID
,
deletedRows
,
msg
);
updateErr
!=
nil
{
log
.
Printf
(
"[UsageCleanup] update task failed failed: task=%d err=%v"
,
taskID
,
updateErr
)
}
}
func
(
s
*
UsageCleanupService
)
isTaskCanceled
(
ctx
context
.
Context
,
taskID
int64
)
(
bool
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
false
,
fmt
.
Errorf
(
"cleanup service not ready"
)
}
checkCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
status
,
err
:=
s
.
repo
.
GetTaskStatus
(
checkCtx
,
taskID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
false
,
nil
}
return
false
,
err
}
if
status
==
UsageCleanupStatusCanceled
{
log
.
Printf
(
"[UsageCleanup] task cancel detected: task=%d"
,
taskID
)
}
return
status
==
UsageCleanupStatusCanceled
,
nil
}
func
(
s
*
UsageCleanupService
)
validateFilters
(
filters
UsageCleanupFilters
)
error
{
if
filters
.
StartTime
.
IsZero
()
||
filters
.
EndTime
.
IsZero
()
{
return
infraerrors
.
BadRequest
(
"USAGE_CLEANUP_MISSING_RANGE"
,
"start_date and end_date are required"
)
}
if
filters
.
EndTime
.
Before
(
filters
.
StartTime
)
{
return
infraerrors
.
BadRequest
(
"USAGE_CLEANUP_INVALID_RANGE"
,
"end_date must be after start_date"
)
}
maxDays
:=
s
.
maxRangeDays
()
if
maxDays
>
0
{
delta
:=
filters
.
EndTime
.
Sub
(
filters
.
StartTime
)
if
delta
>
time
.
Duration
(
maxDays
)
*
24
*
time
.
Hour
{
return
infraerrors
.
BadRequest
(
"USAGE_CLEANUP_RANGE_TOO_LARGE"
,
fmt
.
Sprintf
(
"date range exceeds %d days"
,
maxDays
))
}
}
return
nil
}
func
(
s
*
UsageCleanupService
)
CancelTask
(
ctx
context
.
Context
,
taskID
int64
,
canceledBy
int64
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
fmt
.
Errorf
(
"cleanup service not ready"
)
}
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
UsageCleanup
.
Enabled
{
return
infraerrors
.
New
(
http
.
StatusServiceUnavailable
,
"USAGE_CLEANUP_DISABLED"
,
"usage cleanup is disabled"
)
}
if
canceledBy
<=
0
{
return
infraerrors
.
BadRequest
(
"USAGE_CLEANUP_INVALID_CANCELLER"
,
"invalid canceller"
)
}
status
,
err
:=
s
.
repo
.
GetTaskStatus
(
ctx
,
taskID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
infraerrors
.
New
(
http
.
StatusNotFound
,
"USAGE_CLEANUP_TASK_NOT_FOUND"
,
"cleanup task not found"
)
}
return
err
}
log
.
Printf
(
"[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s"
,
taskID
,
canceledBy
,
status
)
if
status
!=
UsageCleanupStatusPending
&&
status
!=
UsageCleanupStatusRunning
{
return
infraerrors
.
New
(
http
.
StatusConflict
,
"USAGE_CLEANUP_CANCEL_CONFLICT"
,
"cleanup task cannot be canceled in current status"
)
}
ok
,
err
:=
s
.
repo
.
CancelTask
(
ctx
,
taskID
,
canceledBy
)
if
err
!=
nil
{
return
err
}
if
!
ok
{
// 状态可能并发改变
return
infraerrors
.
New
(
http
.
StatusConflict
,
"USAGE_CLEANUP_CANCEL_CONFLICT"
,
"cleanup task cannot be canceled in current status"
)
}
log
.
Printf
(
"[UsageCleanup] cancel_task done: task=%d operator=%d"
,
taskID
,
canceledBy
)
return
nil
}
func
sanitizeUsageCleanupFilters
(
filters
*
UsageCleanupFilters
)
{
if
filters
==
nil
{
return
}
if
filters
.
UserID
!=
nil
&&
*
filters
.
UserID
<=
0
{
filters
.
UserID
=
nil
}
if
filters
.
APIKeyID
!=
nil
&&
*
filters
.
APIKeyID
<=
0
{
filters
.
APIKeyID
=
nil
}
if
filters
.
AccountID
!=
nil
&&
*
filters
.
AccountID
<=
0
{
filters
.
AccountID
=
nil
}
if
filters
.
GroupID
!=
nil
&&
*
filters
.
GroupID
<=
0
{
filters
.
GroupID
=
nil
}
if
filters
.
Model
!=
nil
{
model
:=
strings
.
TrimSpace
(
*
filters
.
Model
)
if
model
==
""
{
filters
.
Model
=
nil
}
else
{
filters
.
Model
=
&
model
}
}
if
filters
.
BillingType
!=
nil
&&
*
filters
.
BillingType
<
0
{
filters
.
BillingType
=
nil
}
}
func
(
s
*
UsageCleanupService
)
maxRangeDays
()
int
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
31
}
if
s
.
cfg
.
UsageCleanup
.
MaxRangeDays
>
0
{
return
s
.
cfg
.
UsageCleanup
.
MaxRangeDays
}
return
31
}
func
(
s
*
UsageCleanupService
)
batchSize
()
int
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
5000
}
if
s
.
cfg
.
UsageCleanup
.
BatchSize
>
0
{
return
s
.
cfg
.
UsageCleanup
.
BatchSize
}
return
5000
}
func
(
s
*
UsageCleanupService
)
workerInterval
()
time
.
Duration
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
10
*
time
.
Second
}
if
s
.
cfg
.
UsageCleanup
.
WorkerIntervalSeconds
>
0
{
return
time
.
Duration
(
s
.
cfg
.
UsageCleanup
.
WorkerIntervalSeconds
)
*
time
.
Second
}
return
10
*
time
.
Second
}
func
(
s
*
UsageCleanupService
)
taskTimeout
()
time
.
Duration
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
30
*
time
.
Minute
}
if
s
.
cfg
.
UsageCleanup
.
TaskTimeoutSeconds
>
0
{
return
time
.
Duration
(
s
.
cfg
.
UsageCleanup
.
TaskTimeoutSeconds
)
*
time
.
Second
}
return
30
*
time
.
Minute
}
backend/internal/service/usage_cleanup_service_test.go
0 → 100644
View file @
0170d19f
package
service
import
(
"context"
"database/sql"
"errors"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
cleanupDeleteResponse
struct
{
deleted
int64
err
error
}
type
cleanupDeleteCall
struct
{
filters
UsageCleanupFilters
limit
int
}
type
cleanupMarkCall
struct
{
taskID
int64
deletedRows
int64
errMsg
string
}
type
cleanupRepoStub
struct
{
mu
sync
.
Mutex
created
[]
*
UsageCleanupTask
createErr
error
listTasks
[]
UsageCleanupTask
listResult
*
pagination
.
PaginationResult
listErr
error
claimQueue
[]
*
UsageCleanupTask
claimErr
error
deleteQueue
[]
cleanupDeleteResponse
deleteCalls
[]
cleanupDeleteCall
markSucceeded
[]
cleanupMarkCall
markFailed
[]
cleanupMarkCall
statusByID
map
[
int64
]
string
statusErr
error
progressCalls
[]
cleanupMarkCall
updateErr
error
cancelCalls
[]
int64
cancelErr
error
cancelResult
*
bool
markFailedErr
error
}
type
dashboardRepoStub
struct
{
recomputeErr
error
}
func
(
s
*
dashboardRepoStub
)
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardRepoStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
s
.
recomputeErr
}
func
(
s
*
dashboardRepoStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
return
time
.
Time
{},
nil
}
func
(
s
*
dashboardRepoStub
)
UpdateAggregationWatermark
(
ctx
context
.
Context
,
aggregatedAt
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardRepoStub
)
CleanupAggregates
(
ctx
context
.
Context
,
hourlyCutoff
,
dailyCutoff
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardRepoStub
)
CleanupUsageLogs
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardRepoStub
)
EnsureUsageLogsPartitions
(
ctx
context
.
Context
,
now
time
.
Time
)
error
{
return
nil
}
func
(
s
*
cleanupRepoStub
)
CreateTask
(
ctx
context
.
Context
,
task
*
UsageCleanupTask
)
error
{
if
task
==
nil
{
return
nil
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
createErr
!=
nil
{
return
s
.
createErr
}
if
task
.
ID
==
0
{
task
.
ID
=
int64
(
len
(
s
.
created
)
+
1
)
}
if
task
.
CreatedAt
.
IsZero
()
{
task
.
CreatedAt
=
time
.
Now
()
.
UTC
()
}
if
task
.
UpdatedAt
.
IsZero
()
{
task
.
UpdatedAt
=
task
.
CreatedAt
}
clone
:=
*
task
s
.
created
=
append
(
s
.
created
,
&
clone
)
return
nil
}
func
(
s
*
cleanupRepoStub
)
ListTasks
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
UsageCleanupTask
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
return
s
.
listTasks
,
s
.
listResult
,
s
.
listErr
}
func
(
s
*
cleanupRepoStub
)
ClaimNextPendingTask
(
ctx
context
.
Context
,
staleRunningAfterSeconds
int64
)
(
*
UsageCleanupTask
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
claimErr
!=
nil
{
return
nil
,
s
.
claimErr
}
if
len
(
s
.
claimQueue
)
==
0
{
return
nil
,
nil
}
task
:=
s
.
claimQueue
[
0
]
s
.
claimQueue
=
s
.
claimQueue
[
1
:
]
if
s
.
statusByID
==
nil
{
s
.
statusByID
=
map
[
int64
]
string
{}
}
s
.
statusByID
[
task
.
ID
]
=
UsageCleanupStatusRunning
return
task
,
nil
}
func
(
s
*
cleanupRepoStub
)
GetTaskStatus
(
ctx
context
.
Context
,
taskID
int64
)
(
string
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
statusErr
!=
nil
{
return
""
,
s
.
statusErr
}
if
s
.
statusByID
==
nil
{
return
""
,
sql
.
ErrNoRows
}
status
,
ok
:=
s
.
statusByID
[
taskID
]
if
!
ok
{
return
""
,
sql
.
ErrNoRows
}
return
status
,
nil
}
func
(
s
*
cleanupRepoStub
)
UpdateTaskProgress
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
progressCalls
=
append
(
s
.
progressCalls
,
cleanupMarkCall
{
taskID
:
taskID
,
deletedRows
:
deletedRows
})
if
s
.
updateErr
!=
nil
{
return
s
.
updateErr
}
return
nil
}
func
(
s
*
cleanupRepoStub
)
CancelTask
(
ctx
context
.
Context
,
taskID
int64
,
canceledBy
int64
)
(
bool
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
cancelCalls
=
append
(
s
.
cancelCalls
,
taskID
)
if
s
.
cancelErr
!=
nil
{
return
false
,
s
.
cancelErr
}
if
s
.
cancelResult
!=
nil
{
ok
:=
*
s
.
cancelResult
if
ok
{
if
s
.
statusByID
==
nil
{
s
.
statusByID
=
map
[
int64
]
string
{}
}
s
.
statusByID
[
taskID
]
=
UsageCleanupStatusCanceled
}
return
ok
,
nil
}
if
s
.
statusByID
==
nil
{
s
.
statusByID
=
map
[
int64
]
string
{}
}
status
:=
s
.
statusByID
[
taskID
]
if
status
!=
UsageCleanupStatusPending
&&
status
!=
UsageCleanupStatusRunning
{
return
false
,
nil
}
s
.
statusByID
[
taskID
]
=
UsageCleanupStatusCanceled
return
true
,
nil
}
func
(
s
*
cleanupRepoStub
)
MarkTaskSucceeded
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
markSucceeded
=
append
(
s
.
markSucceeded
,
cleanupMarkCall
{
taskID
:
taskID
,
deletedRows
:
deletedRows
})
if
s
.
statusByID
==
nil
{
s
.
statusByID
=
map
[
int64
]
string
{}
}
s
.
statusByID
[
taskID
]
=
UsageCleanupStatusSucceeded
return
nil
}
func
(
s
*
cleanupRepoStub
)
MarkTaskFailed
(
ctx
context
.
Context
,
taskID
int64
,
deletedRows
int64
,
errorMsg
string
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
markFailed
=
append
(
s
.
markFailed
,
cleanupMarkCall
{
taskID
:
taskID
,
deletedRows
:
deletedRows
,
errMsg
:
errorMsg
})
if
s
.
statusByID
==
nil
{
s
.
statusByID
=
map
[
int64
]
string
{}
}
s
.
statusByID
[
taskID
]
=
UsageCleanupStatusFailed
if
s
.
markFailedErr
!=
nil
{
return
s
.
markFailedErr
}
return
nil
}
func
(
s
*
cleanupRepoStub
)
DeleteUsageLogsBatch
(
ctx
context
.
Context
,
filters
UsageCleanupFilters
,
limit
int
)
(
int64
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
deleteCalls
=
append
(
s
.
deleteCalls
,
cleanupDeleteCall
{
filters
:
filters
,
limit
:
limit
})
if
len
(
s
.
deleteQueue
)
==
0
{
return
0
,
nil
}
resp
:=
s
.
deleteQueue
[
0
]
s
.
deleteQueue
=
s
.
deleteQueue
[
1
:
]
return
resp
.
deleted
,
resp
.
err
}
func
TestUsageCleanupServiceCreateTaskSanitizeFilters
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
MaxRangeDays
:
31
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
userID
:=
int64
(
-
1
)
apiKeyID
:=
int64
(
10
)
model
:=
" gpt-4 "
billingType
:=
int8
(
-
2
)
filters
:=
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
UserID
:
&
userID
,
APIKeyID
:
&
apiKeyID
,
Model
:
&
model
,
BillingType
:
&
billingType
,
}
task
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
filters
,
9
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
UsageCleanupStatusPending
,
task
.
Status
)
require
.
Nil
(
t
,
task
.
Filters
.
UserID
)
require
.
NotNil
(
t
,
task
.
Filters
.
APIKeyID
)
require
.
Equal
(
t
,
apiKeyID
,
*
task
.
Filters
.
APIKeyID
)
require
.
NotNil
(
t
,
task
.
Filters
.
Model
)
require
.
Equal
(
t
,
"gpt-4"
,
*
task
.
Filters
.
Model
)
require
.
Nil
(
t
,
task
.
Filters
.
BillingType
)
require
.
Equal
(
t
,
int64
(
9
),
task
.
CreatedBy
)
}
func
TestUsageCleanupServiceCreateTaskInvalidCreator
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
filters
:=
UsageCleanupFilters
{
StartTime
:
time
.
Now
(),
EndTime
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
}
_
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
filters
,
0
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"USAGE_CLEANUP_INVALID_CREATOR"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCreateTaskDisabled
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
false
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
filters
:=
UsageCleanupFilters
{
StartTime
:
time
.
Now
(),
EndTime
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
}
_
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
filters
,
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
infraerrors
.
Code
(
err
))
require
.
Equal
(
t
,
"USAGE_CLEANUP_DISABLED"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCreateTaskRangeTooLarge
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
MaxRangeDays
:
1
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
48
*
time
.
Hour
)
filters
:=
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}
_
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
filters
,
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"USAGE_CLEANUP_RANGE_TOO_LARGE"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCreateTaskMissingRange
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
_
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
UsageCleanupFilters
{},
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"USAGE_CLEANUP_MISSING_RANGE"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCreateTaskRepoError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
createErr
:
errors
.
New
(
"db down"
)}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
filters
:=
UsageCleanupFilters
{
StartTime
:
time
.
Now
(),
EndTime
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
}
_
,
err
:=
svc
.
CreateTask
(
context
.
Background
(),
filters
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"create cleanup task"
)
}
func
TestUsageCleanupServiceRunOnceSuccess
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
2
*
time
.
Hour
)
repo
:=
&
cleanupRepoStub
{
claimQueue
:
[]
*
UsageCleanupTask
{
{
ID
:
5
,
Filters
:
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
}},
},
deleteQueue
:
[]
cleanupDeleteResponse
{
{
deleted
:
2
},
{
deleted
:
2
},
{
deleted
:
1
},
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
,
TaskTimeoutSeconds
:
30
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
svc
.
runOnce
()
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
deleteCalls
,
3
)
require
.
Equal
(
t
,
2
,
repo
.
deleteCalls
[
0
]
.
limit
)
require
.
True
(
t
,
repo
.
deleteCalls
[
0
]
.
filters
.
StartTime
.
Equal
(
start
))
require
.
True
(
t
,
repo
.
deleteCalls
[
0
]
.
filters
.
EndTime
.
Equal
(
end
))
require
.
Len
(
t
,
repo
.
markSucceeded
,
1
)
require
.
Empty
(
t
,
repo
.
markFailed
)
require
.
Equal
(
t
,
int64
(
5
),
repo
.
markSucceeded
[
0
]
.
taskID
)
require
.
Equal
(
t
,
int64
(
5
),
repo
.
markSucceeded
[
0
]
.
deletedRows
)
require
.
Equal
(
t
,
2
,
repo
.
deleteCalls
[
0
]
.
limit
)
require
.
Equal
(
t
,
start
,
repo
.
deleteCalls
[
0
]
.
filters
.
StartTime
)
require
.
Equal
(
t
,
end
,
repo
.
deleteCalls
[
0
]
.
filters
.
EndTime
)
}
func
TestUsageCleanupServiceRunOnceClaimError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
claimErr
:
errors
.
New
(
"claim failed"
)}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
svc
.
runOnce
()
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Empty
(
t
,
repo
.
markSucceeded
)
require
.
Empty
(
t
,
repo
.
markFailed
)
}
func
TestUsageCleanupServiceRunOnceAlreadyRunning
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
svc
.
running
=
1
svc
.
runOnce
()
}
func
TestUsageCleanupServiceExecuteTaskFailed
(
t
*
testing
.
T
)
{
longMsg
:=
strings
.
Repeat
(
"x"
,
600
)
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
err
:
errors
.
New
(
longMsg
)},
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
3
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
11
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
(),
EndTime
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
markFailed
,
1
)
require
.
Equal
(
t
,
int64
(
11
),
repo
.
markFailed
[
0
]
.
taskID
)
require
.
Equal
(
t
,
500
,
len
(
repo
.
markFailed
[
0
]
.
errMsg
))
}
func
TestUsageCleanupServiceExecuteTaskProgressError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
deleted
:
2
},
{
deleted
:
0
},
},
updateErr
:
errors
.
New
(
"update failed"
),
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
8
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
markSucceeded
,
1
)
require
.
Empty
(
t
,
repo
.
markFailed
)
require
.
Len
(
t
,
repo
.
progressCalls
,
1
)
}
func
TestUsageCleanupServiceExecuteTaskDeleteCanceled
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
err
:
context
.
Canceled
},
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
12
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Empty
(
t
,
repo
.
markSucceeded
)
require
.
Empty
(
t
,
repo
.
markFailed
)
}
func
TestUsageCleanupServiceExecuteTaskContextCanceled
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
9
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
svc
.
executeTask
(
ctx
,
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Empty
(
t
,
repo
.
markSucceeded
)
require
.
Empty
(
t
,
repo
.
markFailed
)
require
.
Empty
(
t
,
repo
.
deleteCalls
)
}
func
TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
err
:
errors
.
New
(
"boom"
)},
},
markFailedErr
:
errors
.
New
(
"update failed"
),
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
13
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
markFailed
,
1
)
require
.
Equal
(
t
,
int64
(
13
),
repo
.
markFailed
[
0
]
.
taskID
)
}
func
TestUsageCleanupServiceExecuteTaskDashboardRecomputeError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
deleted
:
0
},
},
}
dashboard
:=
NewDashboardAggregationService
(
&
dashboardRepoStub
{},
nil
,
&
config
.
Config
{
DashboardAgg
:
config
.
DashboardAggregationConfig
{
Enabled
:
false
},
})
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
dashboard
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
14
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
markSucceeded
,
1
)
}
func
TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
deleteQueue
:
[]
cleanupDeleteResponse
{
{
deleted
:
0
},
},
}
dashboard
:=
NewDashboardAggregationService
(
&
dashboardRepoStub
{},
nil
,
&
config
.
Config
{
DashboardAgg
:
config
.
DashboardAggregationConfig
{
Enabled
:
true
},
})
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
dashboard
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
15
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Len
(
t
,
repo
.
markSucceeded
,
1
)
}
func
TestUsageCleanupServiceExecuteTaskCanceled
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
3
:
UsageCleanupStatusCanceled
,
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
BatchSize
:
2
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
task
:=
&
UsageCleanupTask
{
ID
:
3
,
Filters
:
UsageCleanupFilters
{
StartTime
:
time
.
Now
()
.
UTC
(),
EndTime
:
time
.
Now
()
.
UTC
()
.
Add
(
time
.
Hour
),
},
}
svc
.
executeTask
(
context
.
Background
(),
task
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Empty
(
t
,
repo
.
deleteCalls
)
require
.
Empty
(
t
,
repo
.
markSucceeded
)
require
.
Empty
(
t
,
repo
.
markFailed
)
}
func
TestUsageCleanupServiceCancelTaskSuccess
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
5
:
UsageCleanupStatusPending
,
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
5
,
9
)
require
.
NoError
(
t
,
err
)
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Equal
(
t
,
UsageCleanupStatusCanceled
,
repo
.
statusByID
[
5
])
require
.
Len
(
t
,
repo
.
cancelCalls
,
1
)
}
func
TestUsageCleanupServiceCancelTaskDisabled
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
false
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
1
,
2
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
infraerrors
.
Code
(
err
))
require
.
Equal
(
t
,
"USAGE_CLEANUP_DISABLED"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCancelTaskNotFound
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
infraerrors
.
Code
(
err
))
require
.
Equal
(
t
,
"USAGE_CLEANUP_TASK_NOT_FOUND"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCancelTaskStatusError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusErr
:
errors
.
New
(
"status broken"
)}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
7
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"status broken"
)
}
func
TestUsageCleanupServiceCancelTaskConflict
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
7
:
UsageCleanupStatusSucceeded
,
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
7
,
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
infraerrors
.
Code
(
err
))
require
.
Equal
(
t
,
"USAGE_CLEANUP_CANCEL_CONFLICT"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCancelTaskRepoConflict
(
t
*
testing
.
T
)
{
shouldCancel
:=
false
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
7
:
UsageCleanupStatusPending
,
},
cancelResult
:
&
shouldCancel
,
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
7
,
1
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
infraerrors
.
Code
(
err
))
require
.
Equal
(
t
,
"USAGE_CLEANUP_CANCEL_CONFLICT"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceCancelTaskRepoError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
7
:
UsageCleanupStatusPending
,
},
cancelErr
:
errors
.
New
(
"cancel failed"
),
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
7
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"cancel failed"
)
}
func
TestUsageCleanupServiceCancelTaskInvalidCanceller
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusByID
:
map
[
int64
]
string
{
7
:
UsageCleanupStatusRunning
,
},
}
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfg
)
err
:=
svc
.
CancelTask
(
context
.
Background
(),
7
,
0
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"USAGE_CLEANUP_INVALID_CANCELLER"
,
infraerrors
.
Reason
(
err
))
}
func
TestUsageCleanupServiceListTasks
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
listTasks
:
[]
UsageCleanupTask
{{
ID
:
1
},
{
ID
:
2
}},
listResult
:
&
pagination
.
PaginationResult
{
Total
:
2
,
Page
:
1
,
PageSize
:
20
,
Pages
:
1
,
},
}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}})
tasks
,
result
,
err
:=
svc
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
tasks
,
2
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Total
)
}
func
TestUsageCleanupServiceListTasksNotReady
(
t
*
testing
.
T
)
{
var
nilSvc
*
UsageCleanupService
_
,
_
,
err
:=
nilSvc
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
Error
(
t
,
err
)
svc
:=
NewUsageCleanupService
(
nil
,
nil
,
nil
,
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}})
_
,
_
,
err
=
svc
.
ListTasks
(
context
.
Background
(),
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
})
require
.
Error
(
t
,
err
)
}
func
TestUsageCleanupServiceDefaultsAndLifecycle
(
t
*
testing
.
T
)
{
var
nilSvc
*
UsageCleanupService
require
.
Equal
(
t
,
31
,
nilSvc
.
maxRangeDays
())
require
.
Equal
(
t
,
5000
,
nilSvc
.
batchSize
())
require
.
Equal
(
t
,
10
*
time
.
Second
,
nilSvc
.
workerInterval
())
require
.
Equal
(
t
,
30
*
time
.
Minute
,
nilSvc
.
taskTimeout
())
nilSvc
.
Start
()
nilSvc
.
Stop
()
repo
:=
&
cleanupRepoStub
{}
cfgDisabled
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
false
}}
svcDisabled
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
cfgDisabled
)
svcDisabled
.
Start
()
svcDisabled
.
Stop
()
timingWheel
,
err
:=
NewTimingWheelService
()
require
.
NoError
(
t
,
err
)
cfg
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
,
WorkerIntervalSeconds
:
5
}}
svc
:=
NewUsageCleanupService
(
repo
,
timingWheel
,
nil
,
cfg
)
require
.
Equal
(
t
,
5
*
time
.
Second
,
svc
.
workerInterval
())
svc
.
Start
()
svc
.
Stop
()
cfgFallback
:=
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}}
svcFallback
:=
NewUsageCleanupService
(
repo
,
timingWheel
,
nil
,
cfgFallback
)
require
.
Equal
(
t
,
31
,
svcFallback
.
maxRangeDays
())
require
.
Equal
(
t
,
5000
,
svcFallback
.
batchSize
())
require
.
Equal
(
t
,
10
*
time
.
Second
,
svcFallback
.
workerInterval
())
svcMissingDeps
:=
NewUsageCleanupService
(
nil
,
nil
,
nil
,
cfgFallback
)
svcMissingDeps
.
Start
()
}
func
TestSanitizeUsageCleanupFiltersModelEmpty
(
t
*
testing
.
T
)
{
model
:=
" "
apiKeyID
:=
int64
(
-
5
)
accountID
:=
int64
(
-
1
)
groupID
:=
int64
(
-
2
)
filters
:=
UsageCleanupFilters
{
UserID
:
&
apiKeyID
,
APIKeyID
:
&
apiKeyID
,
AccountID
:
&
accountID
,
GroupID
:
&
groupID
,
Model
:
&
model
,
}
sanitizeUsageCleanupFilters
(
&
filters
)
require
.
Nil
(
t
,
filters
.
UserID
)
require
.
Nil
(
t
,
filters
.
APIKeyID
)
require
.
Nil
(
t
,
filters
.
AccountID
)
require
.
Nil
(
t
,
filters
.
GroupID
)
require
.
Nil
(
t
,
filters
.
Model
)
}
func
TestDescribeUsageCleanupFiltersAllFields
(
t
*
testing
.
T
)
{
start
:=
time
.
Date
(
2024
,
2
,
1
,
10
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
2
*
time
.
Hour
)
userID
:=
int64
(
1
)
apiKeyID
:=
int64
(
2
)
accountID
:=
int64
(
3
)
groupID
:=
int64
(
4
)
model
:=
" gpt-4 "
stream
:=
true
billingType
:=
int8
(
2
)
filters
:=
UsageCleanupFilters
{
StartTime
:
start
,
EndTime
:
end
,
UserID
:
&
userID
,
APIKeyID
:
&
apiKeyID
,
AccountID
:
&
accountID
,
GroupID
:
&
groupID
,
Model
:
&
model
,
Stream
:
&
stream
,
BillingType
:
&
billingType
,
}
desc
:=
describeUsageCleanupFilters
(
filters
)
require
.
Equal
(
t
,
"start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2"
,
desc
)
}
func
TestUsageCleanupServiceIsTaskCanceledNotFound
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}})
canceled
,
err
:=
svc
.
isTaskCanceled
(
context
.
Background
(),
9
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
canceled
)
}
func
TestUsageCleanupServiceIsTaskCanceledError
(
t
*
testing
.
T
)
{
repo
:=
&
cleanupRepoStub
{
statusErr
:
errors
.
New
(
"status err"
)}
svc
:=
NewUsageCleanupService
(
repo
,
nil
,
nil
,
&
config
.
Config
{
UsageCleanup
:
config
.
UsageCleanupConfig
{
Enabled
:
true
}})
_
,
err
:=
svc
.
isTaskCanceled
(
context
.
Background
(),
9
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"status err"
)
}
backend/internal/service/user_service.go
View file @
0170d19f
...
...
@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
// TOTP 相关方法
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
}
// UpdateProfileRequest 更新用户资料请求
...
...
backend/internal/service/user_subscription_port.go
View file @
0170d19f
...
...
@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
UserSubscription
,
error
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
ExtendExpiry
(
ctx
context
.
Context
,
subscriptionID
int64
,
newExpiresAt
time
.
Time
)
error
...
...
backend/internal/service/wire.go
View file @
0170d19f
package
service
import
(
"context"
"database/sql"
"time"
...
...
@@ -43,9 +44,10 @@ func ProvideTokenRefreshService(
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cacheInvalidator
TokenCacheInvalidator
,
schedulerCache
SchedulerCache
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cacheInvalidator
,
cfg
)
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cacheInvalidator
,
schedulerCache
,
cfg
)
svc
.
Start
()
return
svc
}
...
...
@@ -57,6 +59,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
return
svc
}
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
func
ProvideUsageCleanupService
(
repo
UsageCleanupRepository
,
timingWheel
*
TimingWheelService
,
dashboardAgg
*
DashboardAggregationService
,
cfg
*
config
.
Config
)
*
UsageCleanupService
{
svc
:=
NewUsageCleanupService
(
repo
,
timingWheel
,
dashboardAgg
,
cfg
)
svc
.
Start
()
return
svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func
ProvideAccountExpiryService
(
accountRepo
AccountRepository
)
*
AccountExpiryService
{
svc
:=
NewAccountExpiryService
(
accountRepo
,
time
.
Minute
)
...
...
@@ -64,6 +73,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
return
svc
}
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
func
ProvideSubscriptionExpiryService
(
userSubRepo
UserSubscriptionRepository
)
*
SubscriptionExpiryService
{
svc
:=
NewSubscriptionExpiryService
(
userSubRepo
,
time
.
Minute
)
svc
.
Start
()
return
svc
}
// ProvideTimingWheelService creates and starts TimingWheelService
func
ProvideTimingWheelService
()
(
*
TimingWheelService
,
error
)
{
svc
,
err
:=
NewTimingWheelService
()
...
...
@@ -189,6 +205,8 @@ func ProvideOpsScheduledReportService(
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func
ProvideAPIKeyAuthCacheInvalidator
(
apiKeyService
*
APIKeyService
)
APIKeyAuthCacheInvalidator
{
// Start Pub/Sub subscriber for L1 cache invalidation across instances
apiKeyService
.
StartAuthCacheInvalidationSubscriber
(
context
.
Background
())
return
apiKeyService
}
...
...
@@ -209,6 +227,7 @@ var ProviderSet = wire.NewSet(
ProvidePricingService
,
NewBillingService
,
NewBillingCacheService
,
NewAnnouncementService
,
NewAdminService
,
NewGatewayService
,
NewOpenAIGatewayService
,
...
...
@@ -246,10 +265,13 @@ var ProviderSet = wire.NewSet(
ProvideUpdateService
,
ProvideTokenRefreshService
,
ProvideAccountExpiryService
,
ProvideSubscriptionExpiryService
,
ProvideTimingWheelService
,
ProvideDashboardAggregationService
,
ProvideUsageCleanupService
,
ProvideDeferredService
,
NewAntigravityQuotaFetcher
,
NewUserAttributeService
,
NewUsageCache
,
NewTotpService
,
)
backend/internal/setup/cli.go
View file @
0170d19f
...
...
@@ -149,6 +149,8 @@ func RunCLI() error {
fmt
.
Println
(
" Invalid Redis DB. Must be between 0 and 15."
)
}
cfg
.
Redis
.
EnableTLS
=
promptConfirm
(
reader
,
"Enable Redis TLS?"
)
fmt
.
Println
()
fmt
.
Print
(
"Testing Redis connection... "
)
if
err
:=
TestRedisConnection
(
&
cfg
.
Redis
);
err
!=
nil
{
...
...
@@ -205,6 +207,7 @@ func RunCLI() error {
fmt
.
Println
(
"── Configuration Summary ──"
)
fmt
.
Printf
(
"Database: %s@%s:%d/%s
\n
"
,
cfg
.
Database
.
User
,
cfg
.
Database
.
Host
,
cfg
.
Database
.
Port
,
cfg
.
Database
.
DBName
)
fmt
.
Printf
(
"Redis: %s:%d
\n
"
,
cfg
.
Redis
.
Host
,
cfg
.
Redis
.
Port
)
fmt
.
Printf
(
"Redis TLS: %s
\n
"
,
map
[
bool
]
string
{
true
:
"enabled"
,
false
:
"disabled"
}[
cfg
.
Redis
.
EnableTLS
])
fmt
.
Printf
(
"Admin: %s
\n
"
,
cfg
.
Admin
.
Email
)
fmt
.
Printf
(
"Server: :%d
\n
"
,
cfg
.
Server
.
Port
)
fmt
.
Println
()
...
...
backend/internal/setup/handler.go
View file @
0170d19f
...
...
@@ -180,6 +180,7 @@ type TestRedisRequest struct {
Port
int
`json:"port" binding:"required"`
Password
string
`json:"password"`
DB
int
`json:"db"`
EnableTLS
bool
`json:"enable_tls"`
}
// testRedis tests Redis connection
...
...
@@ -209,6 +210,7 @@ func testRedis(c *gin.Context) {
Port
:
req
.
Port
,
Password
:
req
.
Password
,
DB
:
req
.
DB
,
EnableTLS
:
req
.
EnableTLS
,
}
if
err
:=
TestRedisConnection
(
cfg
);
err
!=
nil
{
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment