Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Show whitespace changes
Inline
Side-by-side
backend/internal/service/account_usage_service_test.go
View file @
0b746501
...
...
@@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
}
}
func
TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt
(
t
*
testing
.
T
)
{
t
.
Parallel
()
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"604800"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"18000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
updates
,
resetAt
,
err
:=
extractOpenAICodexProbeSnapshot
(
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
headers
})
if
err
!=
nil
{
t
.
Fatalf
(
"extractOpenAICodexProbeSnapshot() error = %v"
,
err
)
}
if
len
(
updates
)
==
0
{
t
.
Fatal
(
"expected codex probe updates from 429 headers"
)
}
if
resetAt
==
nil
{
t
.
Fatal
(
"expected resetAt from exhausted codex headers"
)
}
}
func
TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit
(
t
*
testing
.
T
)
{
func
TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra
(
t
*
testing
.
T
)
{
t
.
Parallel
()
repo
:=
&
accountUsageCodexProbeRepo
{
...
...
@@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
}
svc
:=
&
AccountUsageService
{
accountRepo
:
repo
}
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
svc
.
persistOpenAICodexProbeSnapshot
(
321
,
map
[
string
]
any
{
"codex_7d_used_percent"
:
100.0
,
"codex_7d_reset_at"
:
resetAt
.
Format
(
time
.
RFC3339
),
}
,
&
resetAt
)
"codex_7d_reset_at"
:
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
.
Format
(
time
.
RFC3339
),
})
select
{
case
updates
:=
<-
repo
.
updateExtraCh
:
...
...
@@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t
.
Fatalf
(
"codex_7d_used_percent = %v, want 100"
,
got
)
}
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"
waiting for
codex
probe
extra
persistence timed out
"
)
t
.
Fatal
(
"
等待
codex
探测快照写入
extra
超时
"
)
}
select
{
case
got
:=
<-
repo
.
rateLimitCh
:
if
got
.
Before
(
resetAt
.
Add
(
-
time
.
Second
))
||
got
.
After
(
resetAt
.
Add
(
time
.
Second
))
{
t
.
Fatalf
(
"rate limit resetAt = %v, want around %v"
,
got
,
resetAt
)
t
.
Fatalf
(
"不应将探测快照写入运行时限流状态: %v"
,
got
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"waiting for codex probe rate limit persistence timed out"
)
}
func
TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit
(
t
*
testing
.
T
)
{
t
.
Parallel
()
resetAt
:=
time
.
Now
()
.
Add
(
6
*
24
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
repo
:=
&
accountUsageCodexProbeRepo
{
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
}
svc
:=
&
AccountUsageService
{
accountRepo
:
repo
}
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"codex_5h_used_percent"
:
1.0
,
"codex_5h_reset_at"
:
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
.
Format
(
time
.
RFC3339
),
"codex_7d_used_percent"
:
100.0
,
"codex_7d_reset_at"
:
resetAt
.
Format
(
time
.
RFC3339
),
},
}
usage
,
err
:=
svc
.
getOpenAIUsage
(
context
.
Background
(),
account
)
if
err
!=
nil
{
t
.
Fatalf
(
"getOpenAIUsage() error = %v"
,
err
)
}
if
usage
.
SevenDay
==
nil
||
usage
.
SevenDay
.
Utilization
!=
100.0
{
t
.
Fatalf
(
"预期 7 天用量仍然可见,实际为 %#v"
,
usage
.
SevenDay
)
}
if
account
.
RateLimitResetAt
!=
nil
{
t
.
Fatalf
(
"不应让已耗尽的 codex extra 改写运行时限流状态: %v"
,
account
.
RateLimitResetAt
)
}
select
{
case
got
:=
<-
repo
.
rateLimitCh
:
t
.
Fatalf
(
"不应将已耗尽的 codex extra 持久化为运行时限流状态: %v"
,
got
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
}
...
...
backend/internal/service/account_websearch_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestGetWebSearchEmulationMode_Enabled
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeEnabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_Disabled
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"disabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDisabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_Default
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"default"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_UnknownString
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"unknown"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_OldBoolTrue
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
true
},
}
// bool true → tolerant fallback → enabled (not default)
require
.
Equal
(
t
,
WebSearchModeEnabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_OldBoolFalse
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
false
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NilAccount
(
t
*
testing
.
T
)
{
var
a
*
Account
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NilExtra
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
nil
,
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_MissingField
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NonAnthropicPlatform
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NonAPIKeyType
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
backend/internal/service/admin_service.go
View file @
0b746501
...
...
@@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
if
err
!=
nil
{
return
nil
,
0
,
err
}
now
:=
time
.
Now
()
for
i
:=
range
accounts
{
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
&
accounts
[
i
],
now
)
}
return
accounts
,
result
.
Total
,
nil
}
...
...
backend/internal/service/admin_service_apikey_test.go
View file @
0b746501
...
...
@@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
panic
(
"unexpected"
)
}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type
apiKeyRepoStubForGroupUpdate
struct
{
...
...
@@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
UpdateGroupIDByUserAndGroup
(
context
.
Context
,
int64
,
int64
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
...
...
@@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
GetRateLimitData
(
context
.
Context
,
int64
)
(
*
APIKeyRateLimitData
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
UpdateGroupIDByUserAndGroup
(
context
.
Context
,
int64
,
int64
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type
groupRepoStubForGroupUpdate
struct
{
...
...
backend/internal/service/admin_service_clear_error_test.go
View file @
0b746501
backend/internal/service/api_key_auth_cache.go
View file @
0b746501
...
...
@@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
Role
string
`json:"role"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
// Balance notification fields (required for CheckBalanceAfterDeduction)
Email
string
`json:"email"`
Username
string
`json:"username"`
BalanceNotifyEnabled
bool
`json:"balance_notify_enabled"`
BalanceNotifyThresholdType
string
`json:"balance_notify_threshold_type"`
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails,omitempty"`
TotalRecharged
float64
`json:"total_recharged"`
}
// APIKeyAuthGroupSnapshot 分组快照
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
0b746501
...
...
@@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"log/slog"
"math/rand/v2"
"time"
...
...
@@ -13,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const
apiKeyAuthSnapshotVersion
=
3
const
apiKeyAuthSnapshotVersion
=
5
// v5: added TotalRecharged for percentage threshold
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
...
...
@@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s
.
authCacheL1
.
Del
(
cacheKey
)
});
err
!=
nil
{
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println
(
"[Service] Warning:
failed to start auth cache invalidation subscriber
:
"
,
err
.
Error
()
)
slog
.
Warn
(
"
failed to start auth cache invalidation subscriber"
,
"
err
or"
,
err
)
}
}
...
...
@@ -224,6 +225,13 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Role
:
apiKey
.
User
.
Role
,
Balance
:
apiKey
.
User
.
Balance
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
Email
:
apiKey
.
User
.
Email
,
Username
:
apiKey
.
User
.
Username
,
BalanceNotifyEnabled
:
apiKey
.
User
.
BalanceNotifyEnabled
,
BalanceNotifyThresholdType
:
apiKey
.
User
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
apiKey
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
apiKey
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
apiKey
.
User
.
TotalRecharged
,
},
}
if
apiKey
.
Group
!=
nil
{
...
...
@@ -279,6 +287,13 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Role
:
snapshot
.
User
.
Role
,
Balance
:
snapshot
.
User
.
Balance
,
Concurrency
:
snapshot
.
User
.
Concurrency
,
Email
:
snapshot
.
User
.
Email
,
Username
:
snapshot
.
User
.
Username
,
BalanceNotifyEnabled
:
snapshot
.
User
.
BalanceNotifyEnabled
,
BalanceNotifyThresholdType
:
snapshot
.
User
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
snapshot
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
snapshot
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
snapshot
.
User
.
TotalRecharged
,
},
}
if
snapshot
.
Group
!=
nil
{
...
...
backend/internal/service/auth_service_register_test.go
View file @
0b746501
...
...
@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
}
func
(
s
*
emailCacheStub
)
GetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeleteNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
}
...
...
@@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return
nil
}
func
(
s
*
emailCacheStub
)
GetNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
s
*
emailCacheStub
)
IncrNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
,
window
time
.
Duration
)
(
int64
,
error
)
{
return
0
,
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
...
...
backend/internal/service/balance_notify_check_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
// in-memory settings repo and a non-nil emailService so that the guard-clause
// nil-checks pass. The emailService is intentionally minimal — tests must
// avoid crossing scenarios that would actually dispatch emails.
func
newBalanceNotifyServiceForTest
()
(
*
BalanceNotifyService
,
*
mockSettingRepo
)
{
repo
:=
newMockSettingRepo
()
// EmailService is a concrete type; construct with the same repo so that
// any accidental fallback reads still succeed. Tests should not trigger a
// crossing that reaches SendEmail.
email
:=
NewEmailService
(
repo
,
nil
)
return
NewBalanceNotifyService
(
email
,
repo
,
nil
),
repo
}
// ---------- guard clauses ----------
func
TestCheckBalanceAfterDeduction_NilUser
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
// Should not panic.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
nil
,
100
,
50
)
}
func
TestCheckBalanceAfterDeduction_UserNotifyDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"10"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
false
}
// Even with a crossing, disabled flag short-circuits.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_GlobalDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"false"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_ThresholdZero
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"0"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_UserThresholdOverride
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"100"
// global default
customThreshold
:=
5.0
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
,
BalanceNotifyThreshold
:
&
customThreshold
,
}
// User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
// cross 5, so nothing fires (verified by absence of panic).
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_NoCrossingNotFired
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"10"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
// 100 -> 95, both remain above threshold=10, no crossing.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
100
,
5
)
// 5 -> 3, both already below threshold, no crossing (only fires on first
// cross from above-to-below).
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
5
,
2
)
}
// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
func
TestCheckAccountQuotaAfterIncrement_NilAccount
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
// Should not panic.
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
nil
,
10
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_ZeroCost
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
0
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_NegativeCost
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
-
5
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_GlobalDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"false"
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_daily_limit"
:
1000.0
,
"quota_daily_used"
:
950.0
,
},
}
// Global disabled → no processing even if a dim would cross.
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
100
,
nil
)
}
// ---------- sanity: internal helpers still work ----------
func
TestGetBalanceNotifyConfig_AllFields
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"12.5"
repo
.
data
[
SettingKeyBalanceLowNotifyRechargeURL
]
=
"https://example.com/pay"
enabled
,
threshold
,
url
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
True
(
t
,
enabled
)
require
.
Equal
(
t
,
12.5
,
threshold
)
require
.
Equal
(
t
,
"https://example.com/pay"
,
url
)
}
func
TestGetBalanceNotifyConfig_Disabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"false"
enabled
,
_
,
_
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
False
(
t
,
enabled
)
}
func
TestGetBalanceNotifyConfig_InvalidThreshold
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"not-a-number"
enabled
,
threshold
,
_
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
True
(
t
,
enabled
)
require
.
Equal
(
t
,
0.0
,
threshold
)
}
func
TestIsAccountQuotaNotifyEnabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
// Missing key → false
require
.
False
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
// Explicit "false"
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"false"
require
.
False
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
// Explicit "true"
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"true"
require
.
True
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
}
func
TestGetSiteName_FallsBackToDefault
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
name
:=
s
.
getSiteName
(
context
.
Background
())
require
.
Equal
(
t
,
defaultSiteName
,
name
)
}
func
TestGetSiteName_Configured
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeySiteName
]
=
"My Site"
require
.
Equal
(
t
,
"My Site"
,
s
.
getSiteName
(
context
.
Background
()))
}
// ---------- crossedDownward ----------
func
TestCrossedDownward_CrossesBelow
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance < threshold → true
require
.
True
(
t
,
crossedDownward
(
100
,
5
,
10
))
}
func
TestCrossedDownward_ExactlyAtThreshold
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance == threshold → false (not below)
require
.
False
(
t
,
crossedDownward
(
100
,
10
,
10
))
}
func
TestCrossedDownward_OldExactlyAtThreshold_NewBelow
(
t
*
testing
.
T
)
{
// oldBalance == threshold, newBalance < threshold → true
// (at-or-above → below counts as a crossing)
require
.
True
(
t
,
crossedDownward
(
10
,
5
,
10
))
}
func
TestCrossedDownward_AlreadyBelow
(
t
*
testing
.
T
)
{
// oldBalance < threshold → false (already below, no new crossing)
require
.
False
(
t
,
crossedDownward
(
5
,
3
,
10
))
}
func
TestCrossedDownward_BothAbove
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance > threshold → false (no crossing)
require
.
False
(
t
,
crossedDownward
(
100
,
50
,
10
))
}
func
TestCrossedDownward_ZeroThreshold
(
t
*
testing
.
T
)
{
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
// Typical case: positive balances should not fire when threshold is 0.
require
.
False
(
t
,
crossedDownward
(
10
,
5
,
0
))
require
.
False
(
t
,
crossedDownward
(
0
,
0
,
0
))
}
func
TestCrossedDownward_ZeroThreshold_NegativeNew
(
t
*
testing
.
T
)
{
// Edge case: newBalance goes negative with threshold=0.
require
.
True
(
t
,
crossedDownward
(
5
,
-
1
,
0
))
}
func
TestCrossedDownward_NegativeValues
(
t
*
testing
.
T
)
{
// Both already negative, threshold is positive → no crossing (already below).
require
.
False
(
t
,
crossedDownward
(
-
5
,
-
10
,
10
))
}
func
TestCrossedDownward_LargeDecrement
(
t
*
testing
.
T
)
{
// A single large deduction crosses the threshold.
require
.
True
(
t
,
crossedDownward
(
1000
,
0.5
,
100
))
}
func
TestCrossedDownward_SmallDecrement_NoCrossing
(
t
*
testing
.
T
)
{
// A tiny deduction stays above threshold.
require
.
False
(
t
,
crossedDownward
(
100
,
99.99
,
10
))
}
// ---------- checkQuotaDimCrossings ----------
func
TestCheckQuotaDimCrossings_NoDimensions
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// Empty dims → no crossing, no panic.
s
.
checkQuotaDimCrossings
(
account
,
nil
,
10
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
s
.
checkQuotaDimCrossings
(
account
,
[]
quotaDim
{},
10
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_DisabledDimension
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
false
,
// disabled
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
// Disabled dimension should be skipped even if crossing would occur.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_ZeroThresholdSkipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
0
,
// zero threshold
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
// Zero threshold → skipped.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
300
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
800
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
// Negative resolved threshold → skipped.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
1200
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimWeekly
,
enabled
:
true
,
threshold
:
30
,
thresholdType
:
thresholdTypePercentage
,
currentUsed
:
500
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_ZeroLimit_Skipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// limit=0 → resolvedThreshold returns 0 → skipped.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimTotal
,
enabled
:
true
,
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
50
,
limit
:
0
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_MultipleDims_MixedResults
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// dim1: no crossing (both below effective threshold)
// dim2: disabled (skipped)
// dim3: zero threshold (skipped)
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
300
,
// oldUsed=250, effectiveThreshold=600, both below
limit
:
1000
,
},
{
name
:
quotaDimWeekly
,
enabled
:
false
,
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
900
,
limit
:
1000
,
},
{
name
:
quotaDimTotal
,
enabled
:
true
,
threshold
:
0
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
500
,
limit
:
1000
,
},
}
// None should trigger. No panic expected.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
backend/internal/service/balance_notify_email_body_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// These tests guard against fmt.Sprintf arg-count mismatches in the email
// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
// the output, which these assertions will catch.
// ---------- buildBalanceLowEmailBody ----------
func
TestBuildBalanceLowEmailBody_ContainsRequiredFields
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"Alice"
,
3.14
,
10.0
,
"MySite"
,
""
)
// All substituted values should appear in the output.
require
.
Contains
(
t
,
body
,
"MySite"
)
require
.
Contains
(
t
,
body
,
"Alice"
)
require
.
Contains
(
t
,
body
,
"$3.14"
)
require
.
Contains
(
t
,
body
,
"$10.00"
)
// No fmt.Sprintf format error markers.
require
.
NotContains
(
t
,
body
,
"%!"
)
require
.
NotContains
(
t
,
body
,
"MISSING"
)
require
.
NotContains
(
t
,
body
,
"EXTRA"
)
}
func
TestBuildBalanceLowEmailBody_WithRechargeURL
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"Bob"
,
5.0
,
20.0
,
"Site"
,
"https://example.com/pay"
)
// The recharge anchor element should appear with the URL.
require
.
Contains
(
t
,
body
,
`href="https://example.com/pay"`
)
require
.
Contains
(
t
,
body
,
"立即充值"
)
require
.
NotContains
(
t
,
body
,
"%!"
)
}
func
TestBuildBalanceLowEmailBody_RechargeURLEscaped
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
// Try a URL with characters that need HTML escaping.
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
`https://example.com/?a=1&b=<script>`
)
// `&` and `<` should be escaped in the href.
require
.
Contains
(
t
,
body
,
"&"
)
require
.
Contains
(
t
,
body
,
"<script>"
)
require
.
NotContains
(
t
,
body
,
"<script>"
)
}
func
TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
""
)
// The anchor element should not be rendered (style class may still appear).
require
.
NotContains
(
t
,
body
,
`<a href`
)
require
.
NotContains
(
t
,
body
,
"立即充值"
)
}
// ---------- buildQuotaAlertEmailBody ----------
func
TestBuildQuotaAlertEmailBody_AllFieldsPresent
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
42
,
// accountID
"acc-foo"
,
// accountName
"anthropic"
,
// platform
"日限额 / Daily"
,
// dimLabel
750.50
,
// used
1000.0
,
// limit
249.50
,
// remaining
"$249.50"
,
// thresholdDisplay
"MySite"
,
// siteName
)
require
.
Contains
(
t
,
body
,
"MySite"
)
require
.
Contains
(
t
,
body
,
"#42"
)
require
.
Contains
(
t
,
body
,
"acc-foo"
)
require
.
Contains
(
t
,
body
,
"anthropic"
)
require
.
Contains
(
t
,
body
,
"Daily"
)
require
.
Contains
(
t
,
body
,
"$750.50"
)
require
.
Contains
(
t
,
body
,
"$1000.00"
)
require
.
Contains
(
t
,
body
,
"$249.50"
)
// No format error markers.
require
.
NotContains
(
t
,
body
,
"%!"
)
require
.
NotContains
(
t
,
body
,
"MISSING"
)
require
.
NotContains
(
t
,
body
,
"EXTRA"
)
}
func
TestBuildQuotaAlertEmailBody_UnlimitedDisplay
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
100.0
,
0.0
,
// limit=0 triggers unlimited branch
0.0
,
"30%"
,
"Site"
,
)
require
.
Contains
(
t
,
body
,
"无限制"
)
require
.
Contains
(
t
,
body
,
"Unlimited"
)
}
func
TestBuildQuotaAlertEmailBody_PercentageThresholdDisplay
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
700.0
,
1000.0
,
300.0
,
"30%"
,
// percentage-formatted threshold
"Site"
,
)
require
.
Contains
(
t
,
body
,
"30%"
)
require
.
NotContains
(
t
,
body
,
"%!"
)
}
func
TestBuildQuotaAlertEmailBody_RemainingClampedAtZero
(
t
*
testing
.
T
)
{
// Even though caller is responsible for clamping, this test documents the
// display behavior with remaining=0.
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
1500.0
,
1000.0
,
0.0
,
// used > limit (over-quota)
"$100.00"
,
"Site"
,
)
require
.
Contains
(
t
,
body
,
"$0.00"
)
}
// ---------- sanity checks on the CSS `%%` escape ----------
func
TestBuildBalanceLowEmailBody_NoCSSFormatError
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
""
)
// CSS `linear-gradient(135deg, #f59e0b 0%, #d97706 100%)` should appear with
// literal percent signs (from the %% escape in the template).
require
.
True
(
t
,
strings
.
Contains
(
body
,
"0%"
)
&&
strings
.
Contains
(
body
,
"100%"
),
"CSS gradient percentages not rendered; got: %s"
,
body
)
}
func
TestBuildQuotaAlertEmailBody_NoCSSFormatError
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"d"
,
0
,
0
,
0
,
"$0.00"
,
"Site"
)
require
.
True
(
t
,
strings
.
Contains
(
body
,
"0%"
)
&&
strings
.
Contains
(
body
,
"100%"
),
"CSS gradient percentages not rendered; got: %s"
,
body
)
}
backend/internal/service/balance_notify_service.go
0 → 100644
View file @
0b746501
package
service
import
(
"context"
"fmt"
"html"
"log/slog"
"strconv"
"strings"
"time"
)
const
(
emailSendTimeout
=
30
*
time
.
Second
// Threshold type values
thresholdTypeFixed
=
"fixed"
thresholdTypePercentage
=
"percentage"
// Quota dimension labels
quotaDimDaily
=
"daily"
quotaDimWeekly
=
"weekly"
quotaDimTotal
=
"total"
defaultSiteName
=
"Sub2API"
)
// quotaDimLabels maps dimension names to display labels.
var
quotaDimLabels
=
map
[
string
]
string
{
quotaDimDaily
:
"日限额 / Daily"
,
quotaDimWeekly
:
"周限额 / Weekly"
,
quotaDimTotal
:
"总限额 / Total"
,
}
// AccountQuotaReader provides read access to account quota data.
type
AccountQuotaReader
interface
{
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type
BalanceNotifyService
struct
{
emailService
*
EmailService
settingRepo
SettingRepository
accountRepo
AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func
NewBalanceNotifyService
(
emailService
*
EmailService
,
settingRepo
SettingRepository
,
accountRepo
AccountQuotaReader
)
*
BalanceNotifyService
{
return
&
BalanceNotifyService
{
emailService
:
emailService
,
settingRepo
:
settingRepo
,
accountRepo
:
accountRepo
,
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func
resolveBalanceThreshold
(
threshold
float64
,
thresholdType
string
,
totalRecharged
float64
)
float64
{
if
thresholdType
==
thresholdTypePercentage
&&
totalRecharged
>
0
{
return
totalRecharged
*
threshold
/
100
}
return
threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
func
(
s
*
BalanceNotifyService
)
CheckBalanceAfterDeduction
(
ctx
context
.
Context
,
user
*
User
,
oldBalance
,
cost
float64
)
{
if
!
s
.
canNotifyBalance
(
user
)
{
return
}
effectiveThreshold
,
rechargeURL
,
ok
:=
s
.
resolveUserEffectiveThreshold
(
ctx
,
user
)
if
!
ok
{
return
}
newBalance
:=
oldBalance
-
cost
if
!
crossedDownward
(
oldBalance
,
newBalance
,
effectiveThreshold
)
{
return
}
s
.
dispatchBalanceLowEmail
(
ctx
,
user
,
newBalance
,
effectiveThreshold
,
rechargeURL
)
}
// canNotifyBalance checks nil guards and user-level toggle.
func
(
s
*
BalanceNotifyService
)
canNotifyBalance
(
user
*
User
)
bool
{
if
user
==
nil
||
s
.
emailService
==
nil
||
s
.
settingRepo
==
nil
{
return
false
}
return
user
.
BalanceNotifyEnabled
}
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
// Returns ok=false when notifications should be skipped.
func
(
s
*
BalanceNotifyService
)
resolveUserEffectiveThreshold
(
ctx
context
.
Context
,
user
*
User
)
(
effectiveThreshold
float64
,
rechargeURL
string
,
ok
bool
)
{
globalEnabled
,
globalThreshold
,
rechargeURL
:=
s
.
getBalanceNotifyConfig
(
ctx
)
if
!
globalEnabled
{
return
0
,
""
,
false
}
threshold
:=
globalThreshold
if
user
.
BalanceNotifyThreshold
!=
nil
{
threshold
=
*
user
.
BalanceNotifyThreshold
}
if
threshold
<=
0
{
return
0
,
""
,
false
}
effectiveThreshold
=
resolveBalanceThreshold
(
threshold
,
user
.
BalanceNotifyThresholdType
,
user
.
TotalRecharged
)
if
effectiveThreshold
<=
0
{
return
0
,
""
,
false
}
return
effectiveThreshold
,
rechargeURL
,
true
}
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
func
crossedDownward
(
oldV
,
newV
,
threshold
float64
)
bool
{
return
oldV
>=
threshold
&&
newV
<
threshold
}
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
func
(
s
*
BalanceNotifyService
)
dispatchBalanceLowEmail
(
ctx
context
.
Context
,
user
*
User
,
newBalance
,
threshold
float64
,
rechargeURL
string
)
{
siteName
:=
s
.
getSiteName
(
ctx
)
recipients
:=
s
.
collectBalanceNotifyRecipients
(
user
)
slog
.
Info
(
"CheckBalanceAfterDeduction: sending notification"
,
"user_id"
,
user
.
ID
,
"recipients"
,
recipients
,
"new_balance"
,
newBalance
,
"threshold"
,
threshold
)
go
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in balance notification"
,
"recover"
,
r
)
}
}()
s
.
sendBalanceLowEmails
(
recipients
,
user
.
Username
,
user
.
Email
,
newBalance
,
threshold
,
siteName
,
rechargeURL
)
}()
}
// quotaDim describes one quota dimension for notification checking.
type
quotaDim
struct
{
name
string
enabled
bool
threshold
float64
thresholdType
string
// "fixed" (default) or "percentage"
currentUsed
float64
limit
float64
}
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
// The threshold represents how much quota REMAINS when the alert fires:
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
func
(
d
quotaDim
)
resolvedThreshold
()
float64
{
if
d
.
limit
<=
0
{
return
0
}
if
d
.
thresholdType
==
thresholdTypePercentage
{
return
d
.
limit
*
(
1
-
d
.
threshold
/
100
)
}
return
d
.
limit
-
d
.
threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func
buildQuotaDims
(
account
*
Account
)
[]
quotaDim
{
return
[]
quotaDim
{
{
quotaDimDaily
,
account
.
GetQuotaNotifyDailyEnabled
(),
account
.
GetQuotaNotifyDailyThreshold
(),
account
.
GetQuotaNotifyDailyThresholdType
(),
account
.
GetQuotaDailyUsed
(),
account
.
GetQuotaDailyLimit
()},
{
quotaDimWeekly
,
account
.
GetQuotaNotifyWeeklyEnabled
(),
account
.
GetQuotaNotifyWeeklyThreshold
(),
account
.
GetQuotaNotifyWeeklyThresholdType
(),
account
.
GetQuotaWeeklyUsed
(),
account
.
GetQuotaWeeklyLimit
()},
{
quotaDimTotal
,
account
.
GetQuotaNotifyTotalEnabled
(),
account
.
GetQuotaNotifyTotalThreshold
(),
account
.
GetQuotaNotifyTotalThresholdType
(),
account
.
GetQuotaUsed
(),
account
.
GetQuotaLimit
()},
}
}
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
func
buildQuotaDimsFromState
(
account
*
Account
,
state
*
AccountQuotaState
)
[]
quotaDim
{
return
[]
quotaDim
{
{
quotaDimDaily
,
account
.
GetQuotaNotifyDailyEnabled
(),
account
.
GetQuotaNotifyDailyThreshold
(),
account
.
GetQuotaNotifyDailyThresholdType
(),
state
.
DailyUsed
,
state
.
DailyLimit
},
{
quotaDimWeekly
,
account
.
GetQuotaNotifyWeeklyEnabled
(),
account
.
GetQuotaNotifyWeeklyThreshold
(),
account
.
GetQuotaNotifyWeeklyThresholdType
(),
state
.
WeeklyUsed
,
state
.
WeeklyLimit
},
{
quotaDimTotal
,
account
.
GetQuotaNotifyTotalEnabled
(),
account
.
GetQuotaNotifyTotalThreshold
(),
account
.
GetQuotaNotifyTotalThresholdType
(),
state
.
TotalUsed
,
state
.
TotalLimit
},
}
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
func
(
s
*
BalanceNotifyService
)
CheckAccountQuotaAfterIncrement
(
ctx
context
.
Context
,
account
*
Account
,
cost
float64
,
quotaState
*
AccountQuotaState
)
{
if
account
==
nil
||
s
.
emailService
==
nil
||
s
.
settingRepo
==
nil
||
cost
<=
0
{
return
}
if
!
s
.
isAccountQuotaNotifyEnabled
(
ctx
)
{
return
}
adminEmails
:=
s
.
getAccountQuotaNotifyEmails
(
ctx
)
if
len
(
adminEmails
)
==
0
{
return
}
siteName
:=
s
.
getSiteName
(
ctx
)
var
dims
[]
quotaDim
if
quotaState
!=
nil
{
dims
=
buildQuotaDimsFromState
(
account
,
quotaState
)
}
else
{
freshAccount
:=
s
.
fetchFreshAccount
(
ctx
,
account
)
dims
=
buildQuotaDims
(
freshAccount
)
account
=
freshAccount
// use fresh data for alert metadata
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
cost
,
adminEmails
,
siteName
)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func
(
s
*
BalanceNotifyService
)
fetchFreshAccount
(
ctx
context
.
Context
,
snapshot
*
Account
)
*
Account
{
if
s
.
accountRepo
==
nil
{
return
snapshot
}
fresh
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
snapshot
.
ID
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to fetch fresh account for quota notify, using snapshot"
,
"account_id"
,
snapshot
.
ID
,
"error"
,
err
)
return
snapshot
}
return
fresh
}
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
func
(
s
*
BalanceNotifyService
)
checkQuotaDimCrossings
(
account
*
Account
,
dims
[]
quotaDim
,
cost
float64
,
adminEmails
[]
string
,
siteName
string
)
{
for
_
,
dim
:=
range
dims
{
if
!
dim
.
enabled
||
dim
.
threshold
<=
0
{
continue
}
effectiveThreshold
:=
dim
.
resolvedThreshold
()
if
effectiveThreshold
<=
0
{
continue
}
newUsed
:=
dim
.
currentUsed
oldUsed
:=
dim
.
currentUsed
-
cost
if
oldUsed
<
effectiveThreshold
&&
newUsed
>=
effectiveThreshold
{
s
.
asyncSendQuotaAlert
(
adminEmails
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
dim
,
newUsed
,
effectiveThreshold
,
siteName
)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
func
(
s
*
BalanceNotifyService
)
asyncSendQuotaAlert
(
adminEmails
[]
string
,
accountID
int64
,
accountName
,
platform
string
,
dim
quotaDim
,
newUsed
,
effectiveThreshold
float64
,
siteName
string
)
{
go
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in quota notification"
,
"recover"
,
r
)
}
}()
s
.
sendQuotaAlertEmails
(
adminEmails
,
accountID
,
accountName
,
platform
,
dim
,
newUsed
,
siteName
)
}()
}
// getBalanceNotifyConfig reads global balance notification settings.
func
(
s
*
BalanceNotifyService
)
getBalanceNotifyConfig
(
ctx
context
.
Context
)
(
enabled
bool
,
threshold
float64
,
rechargeURL
string
)
{
keys
:=
[]
string
{
SettingKeyBalanceLowNotifyEnabled
,
SettingKeyBalanceLowNotifyThreshold
,
SettingKeyBalanceLowNotifyRechargeURL
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
false
,
0
,
""
}
enabled
=
settings
[
SettingKeyBalanceLowNotifyEnabled
]
==
"true"
if
v
:=
settings
[
SettingKeyBalanceLowNotifyThreshold
];
v
!=
""
{
if
f
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
threshold
=
f
}
}
rechargeURL
=
settings
[
SettingKeyBalanceLowNotifyRechargeURL
]
return
}
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
func
(
s
*
BalanceNotifyService
)
isAccountQuotaNotifyEnabled
(
ctx
context
.
Context
)
bool
{
val
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAccountQuotaNotifyEnabled
)
if
err
!=
nil
{
return
false
}
return
val
==
"true"
}
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
// filtering out disabled and unverified entries.
func
(
s
*
BalanceNotifyService
)
getAccountQuotaNotifyEmails
(
ctx
context
.
Context
)
[]
string
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAccountQuotaNotifyEmails
)
if
err
!=
nil
||
strings
.
TrimSpace
(
raw
)
==
""
||
raw
==
"[]"
{
return
nil
}
entries
:=
ParseNotifyEmails
(
raw
)
if
len
(
entries
)
==
0
{
return
nil
}
return
filterVerifiedEmails
(
entries
)
}
// getSiteName reads site name from settings with fallback.
func
(
s
*
BalanceNotifyService
)
getSiteName
(
ctx
context
.
Context
)
string
{
name
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
if
err
!=
nil
||
name
==
""
{
return
defaultSiteName
}
return
name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func
filterVerifiedEmails
(
entries
[]
NotifyEmailEntry
)
[]
string
{
var
recipients
[]
string
seen
:=
make
(
map
[
string
]
bool
)
for
_
,
entry
:=
range
entries
{
if
entry
.
Disabled
||
!
entry
.
Verified
{
continue
}
email
:=
strings
.
TrimSpace
(
entry
.
Email
)
if
email
==
""
{
continue
}
lower
:=
strings
.
ToLower
(
email
)
if
seen
[
lower
]
{
continue
}
seen
[
lower
]
=
true
recipients
=
append
(
recipients
,
email
)
}
return
recipients
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func
(
s
*
BalanceNotifyService
)
collectBalanceNotifyRecipients
(
user
*
User
)
[]
string
{
return
filterVerifiedEmails
(
user
.
BalanceNotifyExtraEmails
)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.
func
(
s
*
BalanceNotifyService
)
sendEmails
(
recipients
[]
string
,
subject
,
body
string
,
logAttrs
...
any
)
{
if
len
(
recipients
)
==
0
{
slog
.
Warn
(
"sendEmails: no recipients"
,
"subject"
,
subject
)
return
}
for
_
,
to
:=
range
recipients
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
emailSendTimeout
)
if
err
:=
s
.
emailService
.
SendEmail
(
ctx
,
to
,
subject
,
body
);
err
!=
nil
{
attrs
:=
append
([]
any
{
"to"
,
to
,
"error"
,
err
},
logAttrs
...
)
slog
.
Error
(
"failed to send notification"
,
attrs
...
)
}
else
{
slog
.
Info
(
"notification email sent successfully"
,
"to"
,
to
,
"subject"
,
subject
)
}
cancel
()
}
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func
(
s
*
BalanceNotifyService
)
sendBalanceLowEmails
(
recipients
[]
string
,
userName
,
userEmail
string
,
balance
,
threshold
float64
,
siteName
,
rechargeURL
string
)
{
displayName
:=
userName
if
displayName
==
""
{
displayName
=
userEmail
}
subject
:=
fmt
.
Sprintf
(
"[%s] 余额不足提醒 / Balance Low Alert"
,
sanitizeEmailHeader
(
siteName
))
body
:=
s
.
buildBalanceLowEmailBody
(
html
.
EscapeString
(
displayName
),
balance
,
threshold
,
html
.
EscapeString
(
siteName
),
rechargeURL
)
s
.
sendEmails
(
recipients
,
subject
,
body
,
"user_email"
,
userEmail
,
"balance"
,
balance
)
}
// sendQuotaAlertEmails sends quota alert notification to admin emails.
func
(
s
*
BalanceNotifyService
)
sendQuotaAlertEmails
(
adminEmails
[]
string
,
accountID
int64
,
accountName
,
platform
string
,
dim
quotaDim
,
used
float64
,
siteName
string
)
{
dimLabel
:=
quotaDimLabels
[
dim
.
name
]
if
dimLabel
==
""
{
dimLabel
=
dim
.
name
}
// Format the remaining-based threshold for display
thresholdDisplay
:=
fmt
.
Sprintf
(
"$%.2f"
,
dim
.
threshold
)
if
dim
.
thresholdType
==
thresholdTypePercentage
{
thresholdDisplay
=
fmt
.
Sprintf
(
"%.0f%%"
,
dim
.
threshold
)
}
remaining
:=
dim
.
limit
-
used
if
remaining
<
0
{
remaining
=
0
}
subject
:=
fmt
.
Sprintf
(
"[%s] 账号限额告警 / Account Quota Alert - %s"
,
sanitizeEmailHeader
(
siteName
),
sanitizeEmailHeader
(
accountName
))
body
:=
s
.
buildQuotaAlertEmailBody
(
accountID
,
html
.
EscapeString
(
accountName
),
html
.
EscapeString
(
platform
),
html
.
EscapeString
(
dimLabel
),
used
,
dim
.
limit
,
remaining
,
thresholdDisplay
,
html
.
EscapeString
(
siteName
))
s
.
sendEmails
(
adminEmails
,
subject
,
body
,
"account"
,
accountName
,
"dimension"
,
dim
.
name
)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func
sanitizeEmailHeader
(
s
string
)
string
{
return
strings
.
NewReplacer
(
"
\r
"
,
""
,
"
\n
"
,
""
)
.
Replace
(
s
)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
// The recharge button is appended dynamically when rechargeURL is set.
const
balanceLowEmailTemplate
=
`<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333;">%s,您的余额不足</p>
<p style="color: #666;">Dear %s, your balance is running low</p>
<div class="balance">$%.2f</div>
<div class="info">
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong>。</p>
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
<p>请及时充值以免服务中断。</p>
<p>Please top up to avoid service interruption.</p>
</div>
%s
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
const
quotaAlertEmailTemplate
=
`<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; }
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
.metric-label { color: #666; }
.metric-value { font-weight: bold; color: #333; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
<div class="info">
<p>账号剩余额度已低于提醒阈值,请及时关注。</p>
<p>Account remaining quota has fallen below the alert threshold.</p>
</div>
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func
(
s
*
BalanceNotifyService
)
buildBalanceLowEmailBody
(
userName
string
,
balance
,
threshold
float64
,
siteName
,
rechargeURL
string
)
string
{
rechargeBlock
:=
""
if
rechargeURL
!=
""
{
rechargeBlock
=
fmt
.
Sprintf
(
`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`
,
html
.
EscapeString
(
rechargeURL
))
}
return
fmt
.
Sprintf
(
balanceLowEmailTemplate
,
siteName
,
userName
,
userName
,
balance
,
threshold
,
threshold
,
rechargeBlock
)
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func
(
s
*
BalanceNotifyService
)
buildQuotaAlertEmailBody
(
accountID
int64
,
accountName
,
platform
,
dimLabel
string
,
used
,
limit
,
remaining
float64
,
thresholdDisplay
,
siteName
string
)
string
{
limitStr
:=
fmt
.
Sprintf
(
"$%.2f"
,
limit
)
if
limit
<=
0
{
limitStr
=
"无限制 / Unlimited"
}
return
fmt
.
Sprintf
(
quotaAlertEmailTemplate
,
siteName
,
accountID
,
accountName
,
platform
,
dimLabel
,
used
,
limitStr
,
remaining
,
thresholdDisplay
)
}
backend/internal/service/balance_notify_service_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// ---------- resolveBalanceThreshold ----------
func
TestResolveBalanceThreshold_Fixed
(
t
*
testing
.
T
)
{
// Fixed type always returns the raw threshold regardless of totalRecharged.
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypeFixed
,
1000
))
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypeFixed
,
0
))
require
.
Equal
(
t
,
0.0
,
resolveBalanceThreshold
(
0
,
thresholdTypeFixed
,
1000
))
}
func
TestResolveBalanceThreshold_Percentage
(
t
*
testing
.
T
)
{
// 10% of 1000 = 100
require
.
Equal
(
t
,
100.0
,
resolveBalanceThreshold
(
10
,
thresholdTypePercentage
,
1000
))
// 50% of 200 = 100
require
.
Equal
(
t
,
100.0
,
resolveBalanceThreshold
(
50
,
thresholdTypePercentage
,
200
))
}
func
TestResolveBalanceThreshold_PercentageZeroRecharged
(
t
*
testing
.
T
)
{
// When totalRecharged is 0, percentage falls through to raw threshold
// (treated as fixed). This is the defensive behavior.
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypePercentage
,
0
))
}
func
TestResolveBalanceThreshold_EmptyType
(
t
*
testing
.
T
)
{
// Empty type is treated as fixed (not percentage).
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
""
,
1000
))
}
// ---------- quotaDim.resolvedThreshold ----------
func
TestResolvedThreshold_FixedNormal
(
t
*
testing
.
T
)
{
// threshold=400 remaining, limit=1000 → usage trigger at 600
d
:=
quotaDim
{
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
600.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_FixedThresholdExceedsLimit
(
t
*
testing
.
T
)
{
// threshold=1200, limit=1000 → returns negative, callers must skip
d
:=
quotaDim
{
threshold
:
1200
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
-
200.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_FixedThresholdEqualsLimit
(
t
*
testing
.
T
)
{
// threshold=1000, limit=1000 → returns 0 (alert fires at 0 usage)
d
:=
quotaDim
{
threshold
:
1000
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_PercentageNormal
(
t
*
testing
.
T
)
{
// threshold=30%, limit=1000 → usage trigger at 700 (remaining drops to 30%)
d
:=
quotaDim
{
threshold
:
30
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
700.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageZeroPercent
(
t
*
testing
.
T
)
{
// threshold=0%, limit=1000 → fires when remaining drops to 0 (usage=1000)
d
:=
quotaDim
{
threshold
:
0
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
1000.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageHundredPercent
(
t
*
testing
.
T
)
{
// threshold=100%, limit=1000 → fires immediately (remaining drops to 100% i.e. nothing used yet)
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
0.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageOverHundred
(
t
*
testing
.
T
)
{
// threshold=150%, limit=1000 → returns negative (never triggers; callers skip)
d
:=
quotaDim
{
threshold
:
150
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
Less
(
t
,
d
.
resolvedThreshold
(),
0.0
)
}
func
TestResolvedThreshold_ZeroLimit
(
t
*
testing
.
T
)
{
// limit=0 → returns 0 to avoid division and false alerts on unlimited quotas
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
limit
:
0
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_NegativeLimit
(
t
*
testing
.
T
)
{
// Negative limit treated as 0
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
limit
:
-
10
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
// ---------- sanitizeEmailHeader ----------
func
TestSanitizeEmailHeader_CRLF
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"Subject injected"
,
sanitizeEmailHeader
(
"Subject
\r\n
injected"
))
}
func
TestSanitizeEmailHeader_OnlyCR
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"foobar"
,
sanitizeEmailHeader
(
"foo
\r
bar"
))
}
func
TestSanitizeEmailHeader_OnlyLF
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"foobar"
,
sanitizeEmailHeader
(
"foo
\n
bar"
))
}
func
TestSanitizeEmailHeader_Clean
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"Sub2API"
,
sanitizeEmailHeader
(
"Sub2API"
))
}
func
TestSanitizeEmailHeader_Empty
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
sanitizeEmailHeader
(
""
))
}
func
TestSanitizeEmailHeader_MultipleNewlines
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"abc"
,
sanitizeEmailHeader
(
"a
\r\n
b
\r\n
c"
))
}
// ---------- buildQuotaDims ----------
func
TestBuildQuotaDims_AllDimensionsReturned
(
t
*
testing
.
T
)
{
// Use an account with quota notify config across all 3 dimensions.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_notify_daily_threshold_type"
:
thresholdTypeFixed
,
"quota_notify_weekly_enabled"
:
true
,
"quota_notify_weekly_threshold"
:
20.0
,
"quota_notify_weekly_threshold_type"
:
thresholdTypePercentage
,
"quota_notify_total_enabled"
:
false
,
"quota_daily_limit"
:
500.0
,
"quota_weekly_limit"
:
2000.0
,
"quota_limit"
:
10000.0
,
"quota_daily_used"
:
50.0
,
"quota_weekly_used"
:
300.0
,
"quota_used"
:
1000.0
,
},
}
dims
:=
buildQuotaDims
(
a
)
require
.
Len
(
t
,
dims
,
3
)
// Daily
require
.
Equal
(
t
,
quotaDimDaily
,
dims
[
0
]
.
name
)
require
.
True
(
t
,
dims
[
0
]
.
enabled
)
require
.
Equal
(
t
,
100.0
,
dims
[
0
]
.
threshold
)
require
.
Equal
(
t
,
thresholdTypeFixed
,
dims
[
0
]
.
thresholdType
)
require
.
Equal
(
t
,
500.0
,
dims
[
0
]
.
limit
)
require
.
Equal
(
t
,
50.0
,
dims
[
0
]
.
currentUsed
)
// Weekly
require
.
Equal
(
t
,
quotaDimWeekly
,
dims
[
1
]
.
name
)
require
.
True
(
t
,
dims
[
1
]
.
enabled
)
require
.
Equal
(
t
,
20.0
,
dims
[
1
]
.
threshold
)
require
.
Equal
(
t
,
thresholdTypePercentage
,
dims
[
1
]
.
thresholdType
)
require
.
Equal
(
t
,
2000.0
,
dims
[
1
]
.
limit
)
// Total
require
.
Equal
(
t
,
quotaDimTotal
,
dims
[
2
]
.
name
)
require
.
False
(
t
,
dims
[
2
]
.
enabled
)
require
.
Equal
(
t
,
10000.0
,
dims
[
2
]
.
limit
)
require
.
Equal
(
t
,
1000.0
,
dims
[
2
]
.
currentUsed
)
}
func
TestBuildQuotaDims_EmptyExtra
(
t
*
testing
.
T
)
{
// Missing fields default to zero/disabled.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{},
}
dims
:=
buildQuotaDims
(
a
)
require
.
Len
(
t
,
dims
,
3
)
for
_
,
d
:=
range
dims
{
require
.
False
(
t
,
d
.
enabled
)
require
.
Equal
(
t
,
0.0
,
d
.
threshold
)
require
.
Equal
(
t
,
0.0
,
d
.
limit
)
}
}
// ---------- buildQuotaDimsFromState ----------
func
TestBuildQuotaDimsFromState_UsesStateValues
(
t
*
testing
.
T
)
{
// Usage values should come from the state, not the account.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_daily_used"
:
999.0
,
// should be ignored
"quota_daily_limit"
:
999.0
,
// should be ignored
},
}
state
:=
&
AccountQuotaState
{
DailyUsed
:
77.0
,
DailyLimit
:
500.0
,
WeeklyUsed
:
88.0
,
WeeklyLimit
:
2000.0
,
TotalUsed
:
99.0
,
TotalLimit
:
10000.0
,
}
dims
:=
buildQuotaDimsFromState
(
a
,
state
)
require
.
Len
(
t
,
dims
,
3
)
// Settings from account (enabled, threshold, thresholdType)
require
.
True
(
t
,
dims
[
0
]
.
enabled
)
require
.
Equal
(
t
,
100.0
,
dims
[
0
]
.
threshold
)
// Usage from state
require
.
Equal
(
t
,
77.0
,
dims
[
0
]
.
currentUsed
)
require
.
Equal
(
t
,
500.0
,
dims
[
0
]
.
limit
)
require
.
Equal
(
t
,
88.0
,
dims
[
1
]
.
currentUsed
)
require
.
Equal
(
t
,
2000.0
,
dims
[
1
]
.
limit
)
require
.
Equal
(
t
,
99.0
,
dims
[
2
]
.
currentUsed
)
require
.
Equal
(
t
,
10000.0
,
dims
[
2
]
.
limit
)
}
// ---------- collectBalanceNotifyRecipients ----------
func
TestCollectBalanceNotifyRecipients_Empty
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
nil
}
require
.
Empty
(
t
,
s
.
collectBalanceNotifyRecipients
(
u
))
}
func
TestCollectBalanceNotifyRecipients_FiltersDisabledAndUnverified
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
"a@example.com"
,
Verified
:
true
,
Disabled
:
false
},
{
Email
:
"b@example.com"
,
Verified
:
true
,
Disabled
:
true
},
// disabled
{
Email
:
"c@example.com"
,
Verified
:
false
,
Disabled
:
false
},
// unverified
{
Email
:
"d@example.com"
,
Verified
:
true
,
Disabled
:
false
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"a@example.com"
,
"d@example.com"
},
got
)
}
func
TestCollectBalanceNotifyRecipients_DeduplicatesCaseInsensitive
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
"User@Example.com"
,
Verified
:
true
},
{
Email
:
"user@example.com"
,
Verified
:
true
},
{
Email
:
"USER@EXAMPLE.COM"
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Len
(
t
,
got
,
1
)
// The original casing of the first entry is preserved.
require
.
Equal
(
t
,
"User@Example.com"
,
got
[
0
])
}
func
TestCollectBalanceNotifyRecipients_SkipsEmpty
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
" "
,
Verified
:
true
},
{
Email
:
""
,
Verified
:
true
},
{
Email
:
"valid@example.com"
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"valid@example.com"
},
got
)
}
func
TestCollectBalanceNotifyRecipients_TrimsWhitespace
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
" trimmed@example.com "
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"trimmed@example.com"
},
got
)
}
backend/internal/service/billing_service_test.go
View file @
0b746501
...
...
@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
ActualCost
,
1e-10
)
}
func
TestIsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
...
...
@@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
require
.
InDelta
(
t
,
1.5
,
pricing
.
LongContextInputMultiplier
,
1e-12
)
require
.
InDelta
(
t
,
1.25
,
pricing
.
LongContextOutputMultiplier
,
1e-12
)
}
// ---------------------------------------------------------------------------
// GetModelPricingWithChannel
// ---------------------------------------------------------------------------
func
TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
pricing
)
// Should be identical to GetModelPricing
original
,
err
:=
svc
.
GetModelPricing
(
"claude-sonnet-4"
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
original
.
InputPricePerToken
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
OutputPricePerToken
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
CacheCreationPricePerToken
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
CacheReadPricePerToken
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideInputPriceOnly
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
99e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// InputPrice overridden (both normal and priority)
require
.
InDelta
(
t
,
99e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
99e-6
,
pricing
.
InputPricePerTokenPriority
,
1e-12
)
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
require
.
InDelta
(
t
,
15e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideOutputPriceOnly
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
OutputPrice
:
testPtrFloat64
(
88e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// OutputPrice overridden
require
.
InDelta
(
t
,
88e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
88e-6
,
pricing
.
OutputPricePerTokenPriority
,
1e-12
)
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
require
.
InDelta
(
t
,
3e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideAllFields
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
10e-6
),
OutputPrice
:
testPtrFloat64
(
20e-6
),
CacheWritePrice
:
testPtrFloat64
(
5e-6
),
CacheReadPrice
:
testPtrFloat64
(
1e-6
),
ImageOutputPrice
:
testPtrFloat64
(
50e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
10e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
10e-6
,
pricing
.
InputPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
20e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
20e-6
,
pricing
.
OutputPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreation5mPrice
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreation1hPrice
,
1e-12
)
require
.
InDelta
(
t
,
1e-6
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
1e-6
,
pricing
.
CacheReadPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
50e-6
,
pricing
.
ImageOutputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
CacheWritePrice
:
testPtrFloat64
(
7e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreation5mPrice
,
1e-12
)
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreation1hPrice
,
1e-12
)
}
func
TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
CacheReadPrice
:
testPtrFloat64
(
2e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// CacheReadPrice should set both normal and priority
require
.
InDelta
(
t
,
2e-6
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
2e-6
,
pricing
.
CacheReadPricePerTokenPriority
,
1e-12
)
}
func
TestGetModelPricingWithChannel_UnknownModelReturnsError
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
1e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"totally-unknown-model"
,
chPricing
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
pricing
)
require
.
Contains
(
t
,
err
.
Error
(),
"pricing not found"
)
}
backend/internal/service/billing_service_unified_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// CalculateCostUnified
// ---------------------------------------------------------------------------
func
TestCalculateCostUnified_NilResolver_FallsBackToOldPath
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
input
:=
CostInput
{
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
nil
,
// no resolver
}
cost
,
err
:=
svc
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
// Should match the old-path result exactly
expected
,
err
:=
svc
.
calculateCostInternal
(
"claude-sonnet-4"
,
tokens
,
1.0
,
""
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
expected
.
TotalCost
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
require
.
Empty
(
t
,
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_TokenMode
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.5
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
expectedTotal
:=
1000
*
3e-6
+
500
*
15e-6
require
.
InDelta
(
t
,
expectedTotal
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedTotal
*
1.5
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModeToken
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_PerRequestMode
(
t
*
testing
.
T
)
{
// Set up a ChannelService with a per-request pricing channel
cs
:=
newTestChannelServiceWithCache
(
t
,
&
channelCache
{
pricingByGroupModel
:
map
[
channelModelKey
]
*
ChannelModelPricing
{
{
groupID
:
1
,
model
:
"claude-sonnet-4"
}
:
{
BillingMode
:
BillingModePerRequest
,
PerRequestPrice
:
testPtrFloat64
(
0.05
),
},
},
channelByGroupID
:
map
[
int64
]
*
Channel
{
1
:
{
ID
:
1
,
Status
:
StatusActive
},
},
groupPlatform
:
map
[
int64
]
string
{
1
:
""
},
wildcardByGroupPlatform
:
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
{},
mappingByGroupModel
:
map
[
channelModelKey
]
string
{},
wildcardMappingByGP
:
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
{},
byID
:
map
[
int64
]
*
Channel
{},
})
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
cs
,
bs
)
groupID
:=
int64
(
1
)
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
GroupID
:
&
groupID
,
Tokens
:
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
},
RequestCount
:
3
,
RateMultiplier
:
2.0
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 3 requests * $0.05 = $0.15
require
.
InDelta
(
t
,
0.15
,
cost
.
TotalCost
,
1e-10
)
// ActualCost = 0.15 * 2.0 = 0.30
require
.
InDelta
(
t
,
0.30
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModePerRequest
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_ImageMode
(
t
*
testing
.
T
)
{
cs
:=
newTestChannelServiceWithCache
(
t
,
&
channelCache
{
pricingByGroupModel
:
map
[
channelModelKey
]
*
ChannelModelPricing
{
{
groupID
:
2
,
model
:
"gemini-image"
}
:
{
BillingMode
:
BillingModeImage
,
PerRequestPrice
:
testPtrFloat64
(
0.10
),
},
},
channelByGroupID
:
map
[
int64
]
*
Channel
{
2
:
{
ID
:
2
,
Status
:
StatusActive
},
},
groupPlatform
:
map
[
int64
]
string
{
2
:
""
},
wildcardByGroupPlatform
:
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
{},
mappingByGroupModel
:
map
[
channelModelKey
]
string
{},
wildcardMappingByGP
:
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
{},
byID
:
map
[
int64
]
*
Channel
{},
})
bs
:=
&
BillingService
{
cfg
:
&
config
.
Config
{},
fallbackPrices
:
map
[
string
]
*
ModelPricing
{},
}
resolver
:=
NewModelPricingResolver
(
cs
,
bs
)
groupID
:=
int64
(
2
)
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"gemini-image"
,
GroupID
:
&
groupID
,
Tokens
:
UsageTokens
{},
RequestCount
:
2
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 2 * $0.10 = $0.20
require
.
InDelta
(
t
,
0.20
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
0.20
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModeImage
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
costZero
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
0
,
// should default to 1.0
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costZero
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
}
costNeg
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
-
5.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costNeg
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostUnified_BillingModeFieldFilled
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
cost
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
UsageTokens
{
InputTokens
:
100
},
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"token"
,
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_UsesPreResolvedPricing
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
preResolved
:=
&
ResolvedPricing
{
Mode
:
BillingModePerRequest
,
DefaultPerRequestPrice
:
0.07
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
UsageTokens
{
InputTokens
:
100
},
RequestCount
:
2
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
Resolved
:
preResolved
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 2 * $0.07 = $0.14
require
.
InDelta
(
t
,
0.14
,
cost
.
TotalCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModePerRequest
),
cost
.
BillingMode
)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
// cache snapshot, bypassing the repository layer entirely.
func
newTestChannelServiceWithCache
(
t
*
testing
.
T
,
cache
*
channelCache
)
*
ChannelService
{
t
.
Helper
()
cs
:=
&
ChannelService
{}
cache
.
loadedAt
=
time
.
Now
()
cs
.
cache
.
Store
(
cache
)
return
cs
}
backend/internal/service/channel.go
View file @
0b746501
...
...
@@ -39,6 +39,8 @@ type Channel struct {
Status
string
BillingModelSource
string
// "requested", "upstream", or "channel_mapped"
RestrictModels
bool
// 是否限制模型(仅允许定价列表中的模型)
Features
string
// 渠道特性描述(JSON 数组),用于支付页面展示
FeaturesConfig
map
[
string
]
any
// 渠道功能配置(如 web search emulation)
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
@@ -48,6 +50,25 @@ type Channel struct {
ModelPricing
[]
ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping
map
[
string
]
map
[
string
]
string
// 账号统计定价
ApplyPricingToAccountStats
bool
// 是否应用渠道模型定价到账号统计
AccountStatsPricingRules
[]
AccountStatsPricingRule
// 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
// AccountStatsPricingRule 账号统计定价规则
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
// 多条规则按 SortOrder 排序,先命中为准。
type
AccountStatsPricingRule
struct
{
ID
int64
ChannelID
int64
Name
string
GroupIDs
[]
int64
AccountIDs
[]
int64
SortOrder
int
Pricing
[]
ChannelModelPricing
// 规则内的模型定价(复用现有定价结构)
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
}
// ChannelModelPricing 渠道模型定价条目
...
...
@@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
cp
.
ModelMapping
[
platform
]
=
inner
}
}
if
c
.
FeaturesConfig
!=
nil
{
cp
.
FeaturesConfig
=
deepCopyFeaturesConfig
(
c
.
FeaturesConfig
)
}
if
c
.
AccountStatsPricingRules
!=
nil
{
cp
.
AccountStatsPricingRules
=
make
([]
AccountStatsPricingRule
,
len
(
c
.
AccountStatsPricingRules
))
for
i
,
rule
:=
range
c
.
AccountStatsPricingRules
{
cp
.
AccountStatsPricingRules
[
i
]
=
rule
if
rule
.
GroupIDs
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
GroupIDs
=
make
([]
int64
,
len
(
rule
.
GroupIDs
))
copy
(
cp
.
AccountStatsPricingRules
[
i
]
.
GroupIDs
,
rule
.
GroupIDs
)
}
if
rule
.
AccountIDs
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
AccountIDs
=
make
([]
int64
,
len
(
rule
.
AccountIDs
))
copy
(
cp
.
AccountStatsPricingRules
[
i
]
.
AccountIDs
,
rule
.
AccountIDs
)
}
if
rule
.
Pricing
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
Pricing
=
make
([]
ChannelModelPricing
,
len
(
rule
.
Pricing
))
for
j
:=
range
rule
.
Pricing
{
cp
.
AccountStatsPricingRules
[
i
]
.
Pricing
[
j
]
=
rule
.
Pricing
[
j
]
.
Clone
()
}
}
}
}
return
&
cp
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func
(
c
*
Channel
)
IsWebSearchEmulationEnabled
(
platform
string
)
bool
{
if
c
==
nil
||
c
.
FeaturesConfig
==
nil
{
return
false
}
wse
,
ok
:=
c
.
FeaturesConfig
[
featureKeyWebSearchEmulation
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
false
}
enabled
,
ok
:=
wse
[
platform
]
.
(
bool
)
return
ok
&&
enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func
deepCopyFeaturesConfig
(
src
map
[
string
]
any
)
map
[
string
]
any
{
dst
:=
make
(
map
[
string
]
any
,
len
(
src
))
for
k
,
v
:=
range
src
{
if
inner
,
ok
:=
v
.
(
map
[
string
]
any
);
ok
{
dst
[
k
]
=
deepCopyFeaturesConfig
(
inner
)
}
else
{
dst
[
k
]
=
v
}
}
return
dst
}
// ValidateIntervals 校验区间列表的合法性。
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
...
...
backend/internal/service/channel_service.go
View file @
0b746501
...
...
@@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
type
channelCache
struct
{
// 热路径查找
pricingByGroupModel
map
[
channelModelKey
]
*
ChannelModelPricing
// (groupID, platform, model) → 定价
wildcardByGroupPlatform
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
// (groupID, platform) → 通配符定价(
前缀长度降序
)
wildcardByGroupPlatform
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
// (groupID, platform) → 通配符定价(
按配置顺序,先匹配先使用
)
mappingByGroupModel
map
[
channelModelKey
]
string
// (groupID, platform, model) → 映射目标
wildcardMappingByGP
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
// (groupID, platform) → 通配符映射(
前缀长度降序
)
wildcardMappingByGP
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
// (groupID, platform) → 通配符映射(
按配置顺序,先匹配先使用
)
channelByGroupID
map
[
int64
]
*
Channel
// groupID → 渠道
groupPlatform
map
[
int64
]
string
// groupID → platform
...
...
@@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache
(
cache
,
ch
,
gid
,
platform
)
}
}
return
cache
}
...
...
@@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return
ch
.
Clone
(),
nil
}
// GetGroupPlatform 获取分组的平台标识(从缓存)
func
(
s
*
ChannelService
)
GetGroupPlatform
(
ctx
context
.
Context
,
groupID
int64
)
string
{
cache
,
err
:=
s
.
loadCache
(
ctx
)
if
err
!=
nil
{
return
""
}
return
cache
.
groupPlatform
[
groupID
]
}
// channelLookup 热路径公共查找结果
type
channelLookup
struct
{
cache
*
channelCache
...
...
@@ -556,13 +566,19 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func
validateChannelConfig
(
pricing
[]
ChannelModelPricing
,
mapping
map
[
string
]
map
[
string
]
string
)
error
{
if
err
:=
validate
NoConfl
ic
t
ing
Model
s
(
pricing
);
err
!=
nil
{
if
err
:=
validate
Pr
icing
Entrie
s
(
pricing
);
err
!=
nil
{
return
err
}
if
err
:=
validatePricingIntervals
(
pricing
);
err
!=
nil
{
return
validateNoConflictingMappings
(
mapping
)
}
// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验),
// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。
func
validatePricingEntries
(
pricing
[]
ChannelModelPricing
)
error
{
if
err
:=
validateNoConflictingModels
(
pricing
);
err
!=
nil
{
return
err
}
if
err
:=
validate
NoConfl
ic
t
ing
Mappings
(
mapp
ing
);
err
!=
nil
{
if
err
:=
validate
Pr
icing
Intervals
(
pric
ing
);
err
!=
nil
{
return
err
}
return
validatePricingBillingMode
(
pricing
)
...
...
@@ -663,6 +679,10 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
GroupIDs
:
input
.
GroupIDs
,
ModelPricing
:
input
.
ModelPricing
,
ModelMapping
:
input
.
ModelMapping
,
Features
:
input
.
Features
,
FeaturesConfig
:
input
.
FeaturesConfig
,
ApplyPricingToAccountStats
:
input
.
ApplyPricingToAccountStats
,
AccountStatsPricingRules
:
input
.
AccountStatsPricingRules
,
}
if
channel
.
BillingModelSource
==
""
{
channel
.
BillingModelSource
=
BillingModelSourceChannelMapped
...
...
@@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
return
nil
,
err
}
for
i
,
rule
:=
range
channel
.
AccountStatsPricingRules
{
if
err
:=
validatePricingEntries
(
rule
.
Pricing
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"account stats pricing rule #%d: %w"
,
i
+
1
,
err
)
}
}
if
err
:=
s
.
repo
.
Create
(
ctx
,
channel
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create channel: %w"
,
err
)
...
...
@@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
return
nil
,
err
}
for
i
,
rule
:=
range
channel
.
AccountStatsPricingRules
{
if
err
:=
validatePricingEntries
(
rule
.
Pricing
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"account stats pricing rule #%d: %w"
,
i
+
1
,
err
)
}
}
oldGroupIDs
:=
s
.
getOldGroupIDs
(
ctx
,
id
)
...
...
@@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if
input
.
RestrictModels
!=
nil
{
channel
.
RestrictModels
=
*
input
.
RestrictModels
}
if
input
.
Features
!=
nil
{
channel
.
Features
=
*
input
.
Features
}
if
input
.
GroupIDs
!=
nil
{
if
err
:=
s
.
checkGroupConflicts
(
ctx
,
channel
.
ID
,
*
input
.
GroupIDs
);
err
!=
nil
{
return
err
...
...
@@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if
input
.
BillingModelSource
!=
""
{
channel
.
BillingModelSource
=
input
.
BillingModelSource
}
if
input
.
FeaturesConfig
!=
nil
{
channel
.
FeaturesConfig
=
input
.
FeaturesConfig
}
if
input
.
ApplyPricingToAccountStats
!=
nil
{
channel
.
ApplyPricingToAccountStats
=
*
input
.
ApplyPricingToAccountStats
}
if
input
.
AccountStatsPricingRules
!=
nil
{
channel
.
AccountStatsPricingRules
=
*
input
.
AccountStatsPricingRules
}
return
nil
}
...
...
@@ -920,6 +962,10 @@ type CreateChannelInput struct {
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
BillingModelSource
string
RestrictModels
bool
Features
string
FeaturesConfig
map
[
string
]
any
ApplyPricingToAccountStats
bool
AccountStatsPricingRules
[]
AccountStatsPricingRule
}
// UpdateChannelInput 更新渠道输入
...
...
@@ -932,4 +978,8 @@ type UpdateChannelInput struct {
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
BillingModelSource
string
RestrictModels
*
bool
Features
*
string
FeaturesConfig
map
[
string
]
any
ApplyPricingToAccountStats
*
bool
AccountStatsPricingRules
*
[]
AccountStatsPricingRule
}
backend/internal/service/channel_websearch_test.go
0 → 100644
View file @
0b746501
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestChannel_IsWebSearchEmulationEnabled_Enabled
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
true
},
},
}
require
.
True
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
true
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"openai"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_Disabled
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
false
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
nil
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_NilChannel
(
t
*
testing
.
T
)
{
var
c
*
Channel
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_WrongStructure
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
true
,
// not a map
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
"yes"
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
backend/internal/service/concurrency_service.go
View file @
0b746501
...
...
@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
map
[
int64
]
int
{},
nil
...
...
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return
result
,
nil
}
return
s
.
cache
.
GetAccountConcurrencyBatch
(
ctx
,
accountIDs
)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
defer
cancel
()
return
s
.
cache
.
GetAccountConcurrencyBatch
(
redisCtx
,
accountIDs
)
}
backend/internal/service/domain_constants.go
View file @
0b746501
...
...
@@ -249,6 +249,18 @@ const (
SettingKeyEnableMetadataPassthrough
=
"enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning
=
"enable_cch_signing"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled
=
"balance_low_notify_enabled"
// 全局开关
SettingKeyBalanceLowNotifyThreshold
=
"balance_low_notify_threshold"
// 默认阈值(USD)
SettingKeyBalanceLowNotifyRechargeURL
=
"balance_low_notify_recharge_url"
// 充值页面 URL
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled
=
"account_quota_notify_enabled"
// 全局开关
SettingKeyAccountQuotaNotifyEmails
=
"account_quota_notify_emails"
// 管理员通知邮箱列表(JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig
=
"web_search_emulation_config"
// JSON 配置
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/email_service.go
View file @
0b746501
...
...
@@ -7,8 +7,9 @@ import (
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"log
/slog
"
"math/big"
"net"
"net/smtp"
"net/url"
"strconv"
...
...
@@ -34,6 +35,11 @@ type EmailCache interface {
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Notify email verification code methods
GetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
...
...
@@ -43,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
// Notify code rate limiting per user
IncrNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
,
window
time
.
Duration
)
(
int64
,
error
)
GetNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
}
// VerificationCodeData represents verification code data
...
...
@@ -50,6 +60,7 @@ type VerificationCodeData struct {
Code
string
Attempts
int
CreatedAt
time
.
Time
ExpiresAt
time
.
Time
// absolute expiry; used to preserve remaining TTL when updating attempts
}
// PasswordResetTokenData represents password reset token data
...
...
@@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return
s
.
SendEmailWithConfig
(
config
,
to
,
subject
,
body
)
}
const
smtpDialTimeout
=
10
*
time
.
Second
const
smtpIOTimeout
=
20
*
time
.
Second
// SendEmailWithConfig 使用指定配置发送邮件
func
(
s
*
EmailService
)
SendEmailWithConfig
(
config
*
SMTPConfig
,
to
,
subject
,
body
string
)
error
{
from
:=
config
.
From
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
to
=
sanitizeEmailHeader
(
to
)
subject
=
sanitizeEmailHeader
(
subject
)
from
:=
sanitizeEmailHeader
(
config
.
From
)
if
config
.
FromName
!=
""
{
from
=
fmt
.
Sprintf
(
"%s <%s>"
,
config
.
FromName
,
config
.
From
)
from
=
fmt
.
Sprintf
(
"%s <%s>"
,
sanitizeEmailHeader
(
config
.
FromName
)
,
sanitizeEmailHeader
(
config
.
From
)
)
}
msg
:=
fmt
.
Sprintf
(
"From: %s
\r\n
To: %s
\r\n
Subject: %s
\r\n
MIME-Version: 1.0
\r\n
Content-Type: text/html; charset=UTF-8
\r\n\r\n
%s"
,
...
...
@@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return
s
.
sendMailTLS
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
}
return
smtp
.
SendMail
(
addr
,
auth
,
config
.
From
,
[]
string
{
to
},
[]
byte
(
msg
))
return
s
.
sendMailPlain
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func
(
s
*
EmailService
)
sendMailPlain
(
addr
string
,
auth
smtp
.
Auth
,
from
,
to
string
,
msg
[]
byte
,
host
string
)
error
{
dialer
:=
&
net
.
Dialer
{
Timeout
:
smtpDialTimeout
}
conn
,
err
:=
dialer
.
Dial
(
"tcp"
,
addr
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp dial: %w"
,
err
)
}
_
=
conn
.
SetDeadline
(
time
.
Now
()
.
Add
(
smtpIOTimeout
))
defer
func
()
{
_
=
conn
.
Close
()
}()
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"new smtp client: %w"
,
err
)
}
defer
func
()
{
_
=
client
.
Close
()
}()
// Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it.
// This mirrors the behavior of smtp.SendMail which we replaced for timeout support.
if
ok
,
_
:=
client
.
Extension
(
"STARTTLS"
);
ok
{
if
err
=
client
.
StartTLS
(
&
tls
.
Config
{
ServerName
:
host
,
MinVersion
:
tls
.
VersionTLS12
});
err
!=
nil
{
return
fmt
.
Errorf
(
"starttls: %w"
,
err
)
}
}
if
err
=
client
.
Auth
(
auth
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp auth: %w"
,
err
)
}
if
err
=
client
.
Mail
(
from
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp mail: %w"
,
err
)
}
if
err
=
client
.
Rcpt
(
to
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp rcpt: %w"
,
err
)
}
w
,
err
:=
client
.
Data
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp data: %w"
,
err
)
}
if
_
,
err
=
w
.
Write
(
msg
);
err
!=
nil
{
return
fmt
.
Errorf
(
"write msg: %w"
,
err
)
}
if
err
=
w
.
Close
();
err
!=
nil
{
return
fmt
.
Errorf
(
"close writer: %w"
,
err
)
}
_
=
client
.
Quit
()
return
nil
}
// sendMailTLS 使用TLS发送邮件
...
...
@@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion
:
tls
.
VersionTLS12
,
}
conn
,
err
:=
tls
.
Dial
(
"tcp"
,
addr
,
tlsConfig
)
dialer
:=
&
net
.
Dialer
{
Timeout
:
smtpDialTimeout
}
conn
,
err
:=
tls
.
DialWithDialer
(
dialer
,
"tcp"
,
addr
,
tlsConfig
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"tls dial: %w"
,
err
)
}
_
=
conn
.
SetDeadline
(
time
.
Now
()
.
Add
(
smtpIOTimeout
))
defer
func
()
{
_
=
conn
.
Close
()
}()
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
...
...
@@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code
:
code
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
(),
ExpiresAt
:
time
.
Now
()
.
Add
(
verifyCodeTTL
),
}
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save verify code: %w"
,
err
)
...
...
@@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
remaining
:=
time
.
Until
(
data
.
ExpiresAt
)
if
remaining
<=
0
{
return
ErrInvalidVerifyCode
}
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
remaining
);
err
!=
nil
{
slog
.
Error
(
"failed to update verification attempt count"
,
"email"
,
email
,
"error"
,
err
)
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
...
...
@@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码
if
err
:=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to delete verification code after success
: %v
"
,
err
)
s
log
.
Error
(
"f
ailed to delete verification code after success
"
,
"email"
,
email
,
"error
"
,
err
)
}
return
nil
}
...
...
@@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] P
assword reset email skipped
(
cooldown
): %s
"
,
email
)
s
log
.
Info
(
"p
assword reset email skipped
due to
cooldown
"
,
"email
"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
...
...
@@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to set password reset cooldown
for %s: %v"
,
email
,
err
)
s
log
.
Error
(
"f
ailed to set password reset cooldown
"
,
"email"
,
email
,
"error"
,
err
)
}
return
nil
...
...
@@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to delete password reset token after consumption
: %v
"
,
err
)
s
log
.
Error
(
"f
ailed to delete password reset token after consumption
"
,
"email"
,
email
,
"error
"
,
err
)
}
return
nil
}
...
...
Prev
1
2
3
4
5
6
7
8
9
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment