Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/account_usage_service_test.go
View file @
0b746501
...
@@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
...
@@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
}
}
}
}
func
TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt
(
t
*
testing
.
T
)
{
func
TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra
(
t
*
testing
.
T
)
{
t
.
Parallel
()
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"604800"
)
headers
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
headers
.
Set
(
"x-codex-secondary-used-percent"
,
"100"
)
headers
.
Set
(
"x-codex-secondary-reset-after-seconds"
,
"18000"
)
headers
.
Set
(
"x-codex-secondary-window-minutes"
,
"300"
)
updates
,
resetAt
,
err
:=
extractOpenAICodexProbeSnapshot
(
&
http
.
Response
{
StatusCode
:
http
.
StatusTooManyRequests
,
Header
:
headers
})
if
err
!=
nil
{
t
.
Fatalf
(
"extractOpenAICodexProbeSnapshot() error = %v"
,
err
)
}
if
len
(
updates
)
==
0
{
t
.
Fatal
(
"expected codex probe updates from 429 headers"
)
}
if
resetAt
==
nil
{
t
.
Fatal
(
"expected resetAt from exhausted codex headers"
)
}
}
func
TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
repo
:=
&
accountUsageCodexProbeRepo
{
repo
:=
&
accountUsageCodexProbeRepo
{
...
@@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
...
@@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
}
}
svc
:=
&
AccountUsageService
{
accountRepo
:
repo
}
svc
:=
&
AccountUsageService
{
accountRepo
:
repo
}
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
svc
.
persistOpenAICodexProbeSnapshot
(
321
,
map
[
string
]
any
{
svc
.
persistOpenAICodexProbeSnapshot
(
321
,
map
[
string
]
any
{
"codex_7d_used_percent"
:
100.0
,
"codex_7d_used_percent"
:
100.0
,
"codex_7d_reset_at"
:
resetAt
.
Format
(
time
.
RFC3339
),
"codex_7d_reset_at"
:
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
.
Format
(
time
.
RFC3339
),
}
,
&
resetAt
)
})
select
{
select
{
case
updates
:=
<-
repo
.
updateExtraCh
:
case
updates
:=
<-
repo
.
updateExtraCh
:
...
@@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
...
@@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t
.
Fatalf
(
"codex_7d_used_percent = %v, want 100"
,
got
)
t
.
Fatalf
(
"codex_7d_used_percent = %v, want 100"
,
got
)
}
}
case
<-
time
.
After
(
2
*
time
.
Second
)
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"
waiting for
codex
probe
extra
persistence timed out
"
)
t
.
Fatal
(
"
等待
codex
探测快照写入
extra
超时
"
)
}
}
select
{
select
{
case
got
:=
<-
repo
.
rateLimitCh
:
case
got
:=
<-
repo
.
rateLimitCh
:
if
got
.
Before
(
resetAt
.
Add
(
-
time
.
Second
))
||
got
.
After
(
resetAt
.
Add
(
time
.
Second
))
{
t
.
Fatalf
(
"不应将探测快照写入运行时限流状态: %v"
,
got
)
t
.
Fatalf
(
"rate limit resetAt = %v, want around %v"
,
got
,
resetAt
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
}
case
<-
time
.
After
(
2
*
time
.
Second
)
:
}
t
.
Fatal
(
"waiting for codex probe rate limit persistence timed out"
)
func
TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit
(
t
*
testing
.
T
)
{
t
.
Parallel
()
resetAt
:=
time
.
Now
()
.
Add
(
6
*
24
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
repo
:=
&
accountUsageCodexProbeRepo
{
rateLimitCh
:
make
(
chan
time
.
Time
,
1
),
}
svc
:=
&
AccountUsageService
{
accountRepo
:
repo
}
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"codex_5h_used_percent"
:
1.0
,
"codex_5h_reset_at"
:
time
.
Now
()
.
Add
(
2
*
time
.
Hour
)
.
UTC
()
.
Truncate
(
time
.
Second
)
.
Format
(
time
.
RFC3339
),
"codex_7d_used_percent"
:
100.0
,
"codex_7d_reset_at"
:
resetAt
.
Format
(
time
.
RFC3339
),
},
}
usage
,
err
:=
svc
.
getOpenAIUsage
(
context
.
Background
(),
account
)
if
err
!=
nil
{
t
.
Fatalf
(
"getOpenAIUsage() error = %v"
,
err
)
}
if
usage
.
SevenDay
==
nil
||
usage
.
SevenDay
.
Utilization
!=
100.0
{
t
.
Fatalf
(
"预期 7 天用量仍然可见,实际为 %#v"
,
usage
.
SevenDay
)
}
if
account
.
RateLimitResetAt
!=
nil
{
t
.
Fatalf
(
"不应让已耗尽的 codex extra 改写运行时限流状态: %v"
,
account
.
RateLimitResetAt
)
}
select
{
case
got
:=
<-
repo
.
rateLimitCh
:
t
.
Fatalf
(
"不应将已耗尽的 codex extra 持久化为运行时限流状态: %v"
,
got
)
case
<-
time
.
After
(
200
*
time
.
Millisecond
)
:
}
}
}
}
...
...
backend/internal/service/account_websearch_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestGetWebSearchEmulationMode_Enabled
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeEnabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_Disabled
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"disabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDisabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_Default
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"default"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_UnknownString
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"unknown"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_OldBoolTrue
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
true
},
}
// bool true → tolerant fallback → enabled (not default)
require
.
Equal
(
t
,
WebSearchModeEnabled
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_OldBoolFalse
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
false
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NilAccount
(
t
*
testing
.
T
)
{
var
a
*
Account
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NilExtra
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
nil
,
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_MissingField
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NonAnthropicPlatform
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
func
TestGetWebSearchEmulationMode_NonAPIKeyType
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
"enabled"
},
}
require
.
Equal
(
t
,
WebSearchModeDefault
,
a
.
GetWebSearchEmulationMode
())
}
backend/internal/service/admin_service.go
View file @
0b746501
...
@@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
...
@@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
err
return
nil
,
0
,
err
}
}
now
:=
time
.
Now
()
for
i
:=
range
accounts
{
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
&
accounts
[
i
],
now
)
}
return
accounts
,
result
.
Total
,
nil
return
accounts
,
result
.
Total
,
nil
}
}
...
...
backend/internal/service/admin_service_apikey_test.go
View file @
0b746501
...
@@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
...
@@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
func
(
s
*
userRepoStubForGroupUpdate
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
userRepoStubForGroupUpdate
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
panic
(
"unexpected"
)
}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type
apiKeyRepoStubForGroupUpdate
struct
{
type
apiKeyRepoStubForGroupUpdate
struct
{
...
@@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
...
@@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
UpdateGroupIDByUserAndGroup
(
context
.
Context
,
int64
,
int64
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
...
@@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
...
@@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
GetRateLimitData
(
context
.
Context
,
int64
)
(
*
APIKeyRateLimitData
,
error
)
{
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
GetRateLimitData
(
context
.
Context
,
int64
)
(
*
APIKeyRateLimitData
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
UpdateGroupIDByUserAndGroup
(
context
.
Context
,
int64
,
int64
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type
groupRepoStubForGroupUpdate
struct
{
type
groupRepoStubForGroupUpdate
struct
{
...
...
backend/internal/service/admin_service_clear_error_test.go
View file @
0b746501
...
@@ -12,12 +12,12 @@ import (
...
@@ -12,12 +12,12 @@ import (
type
accountRepoStubForClearAccountError
struct
{
type
accountRepoStubForClearAccountError
struct
{
mockAccountRepoForGemini
mockAccountRepoForGemini
account
*
Account
account
*
Account
clearErrorCalls
int
clearErrorCalls
int
clearRateLimitCalls
int
clearRateLimitCalls
int
clearAntigravityCalls
int
clearAntigravityCalls
int
clearModelRateLimitCalls
int
clearModelRateLimitCalls
int
clearTempUnschedCalls
int
clearTempUnschedCalls
int
}
}
func
(
r
*
accountRepoStubForClearAccountError
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
func
(
r
*
accountRepoStubForClearAccountError
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
...
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
...
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
repo
:=
&
accountRepoStubForClearAccountError
{
repo
:=
&
accountRepoStubForClearAccountError
{
account
:
&
Account
{
account
:
&
Account
{
ID
:
31
,
ID
:
31
,
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Status
:
StatusError
,
Status
:
StatusError
,
ErrorMessage
:
"refresh failed"
,
ErrorMessage
:
"refresh failed"
,
RateLimitResetAt
:
&
resetAt
,
RateLimitResetAt
:
&
resetAt
,
TempUnschedulableUntil
:
&
until
,
TempUnschedulableUntil
:
&
until
,
TempUnschedulableReason
:
"missing refresh token"
,
TempUnschedulableReason
:
"missing refresh token"
,
},
},
}
}
...
...
backend/internal/service/api_key_auth_cache.go
View file @
0b746501
...
@@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
...
@@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
Role
string
`json:"role"`
Role
string
`json:"role"`
Balance
float64
`json:"balance"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
Concurrency
int
`json:"concurrency"`
// Balance notification fields (required for CheckBalanceAfterDeduction)
Email
string
`json:"email"`
Username
string
`json:"username"`
BalanceNotifyEnabled
bool
`json:"balance_notify_enabled"`
BalanceNotifyThresholdType
string
`json:"balance_notify_threshold_type"`
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
`json:"balance_notify_extra_emails,omitempty"`
TotalRecharged
float64
`json:"total_recharged"`
}
}
// APIKeyAuthGroupSnapshot 分组快照
// APIKeyAuthGroupSnapshot 分组快照
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
0b746501
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"encoding/hex"
"encoding/hex"
"errors"
"errors"
"fmt"
"fmt"
"log/slog"
"math/rand/v2"
"math/rand/v2"
"time"
"time"
...
@@ -13,7 +14,7 @@ import (
...
@@ -13,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
"github.com/dgraph-io/ristretto"
)
)
const
apiKeyAuthSnapshotVersion
=
3
const
apiKeyAuthSnapshotVersion
=
5
// v5: added TotalRecharged for percentage threshold
type
apiKeyAuthCacheConfig
struct
{
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
l1Size
int
...
@@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
...
@@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s
.
authCacheL1
.
Del
(
cacheKey
)
s
.
authCacheL1
.
Del
(
cacheKey
)
});
err
!=
nil
{
});
err
!=
nil
{
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println
(
"[Service] Warning:
failed to start auth cache invalidation subscriber
:
"
,
err
.
Error
()
)
slog
.
Warn
(
"
failed to start auth cache invalidation subscriber"
,
"
err
or"
,
err
)
}
}
}
}
...
@@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
...
@@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d
:
apiKey
.
RateLimit1d
,
RateLimit1d
:
apiKey
.
RateLimit1d
,
RateLimit7d
:
apiKey
.
RateLimit7d
,
RateLimit7d
:
apiKey
.
RateLimit7d
,
User
:
APIKeyAuthUserSnapshot
{
User
:
APIKeyAuthUserSnapshot
{
ID
:
apiKey
.
User
.
ID
,
ID
:
apiKey
.
User
.
ID
,
Status
:
apiKey
.
User
.
Status
,
Status
:
apiKey
.
User
.
Status
,
Role
:
apiKey
.
User
.
Role
,
Role
:
apiKey
.
User
.
Role
,
Balance
:
apiKey
.
User
.
Balance
,
Balance
:
apiKey
.
User
.
Balance
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
Email
:
apiKey
.
User
.
Email
,
Username
:
apiKey
.
User
.
Username
,
BalanceNotifyEnabled
:
apiKey
.
User
.
BalanceNotifyEnabled
,
BalanceNotifyThresholdType
:
apiKey
.
User
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
apiKey
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
apiKey
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
apiKey
.
User
.
TotalRecharged
,
},
},
}
}
if
apiKey
.
Group
!=
nil
{
if
apiKey
.
Group
!=
nil
{
...
@@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
...
@@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d
:
snapshot
.
RateLimit1d
,
RateLimit1d
:
snapshot
.
RateLimit1d
,
RateLimit7d
:
snapshot
.
RateLimit7d
,
RateLimit7d
:
snapshot
.
RateLimit7d
,
User
:
&
User
{
User
:
&
User
{
ID
:
snapshot
.
User
.
ID
,
ID
:
snapshot
.
User
.
ID
,
Status
:
snapshot
.
User
.
Status
,
Status
:
snapshot
.
User
.
Status
,
Role
:
snapshot
.
User
.
Role
,
Role
:
snapshot
.
User
.
Role
,
Balance
:
snapshot
.
User
.
Balance
,
Balance
:
snapshot
.
User
.
Balance
,
Concurrency
:
snapshot
.
User
.
Concurrency
,
Concurrency
:
snapshot
.
User
.
Concurrency
,
Email
:
snapshot
.
User
.
Email
,
Username
:
snapshot
.
User
.
Username
,
BalanceNotifyEnabled
:
snapshot
.
User
.
BalanceNotifyEnabled
,
BalanceNotifyThresholdType
:
snapshot
.
User
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
snapshot
.
User
.
BalanceNotifyThreshold
,
BalanceNotifyExtraEmails
:
snapshot
.
User
.
BalanceNotifyExtraEmails
,
TotalRecharged
:
snapshot
.
User
.
TotalRecharged
,
},
},
}
}
if
snapshot
.
Group
!=
nil
{
if
snapshot
.
Group
!=
nil
{
...
...
backend/internal/service/auth_service_register_test.go
View file @
0b746501
...
@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
...
@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
return
nil
}
}
func
(
s
*
emailCacheStub
)
GetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeleteNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
return
nil
,
nil
}
}
...
@@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
...
@@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return
nil
return
nil
}
}
func
(
s
*
emailCacheStub
)
GetNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
s
*
emailCacheStub
)
IncrNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
,
window
time
.
Duration
)
(
int64
,
error
)
{
return
0
,
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
JWT
:
config
.
JWTConfig
{
...
...
backend/internal/service/balance_notify_check_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
// in-memory settings repo and a non-nil emailService so that the guard-clause
// nil-checks pass. The emailService is intentionally minimal — tests must
// avoid crossing scenarios that would actually dispatch emails.
func
newBalanceNotifyServiceForTest
()
(
*
BalanceNotifyService
,
*
mockSettingRepo
)
{
repo
:=
newMockSettingRepo
()
// EmailService is a concrete type; construct with the same repo so that
// any accidental fallback reads still succeed. Tests should not trigger a
// crossing that reaches SendEmail.
email
:=
NewEmailService
(
repo
,
nil
)
return
NewBalanceNotifyService
(
email
,
repo
,
nil
),
repo
}
// ---------- guard clauses ----------
func
TestCheckBalanceAfterDeduction_NilUser
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
// Should not panic.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
nil
,
100
,
50
)
}
func
TestCheckBalanceAfterDeduction_UserNotifyDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"10"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
false
}
// Even with a crossing, disabled flag short-circuits.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_GlobalDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"false"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_ThresholdZero
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"0"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_UserThresholdOverride
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"100"
// global default
customThreshold
:=
5.0
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
,
BalanceNotifyThreshold
:
&
customThreshold
,
}
// User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
// cross 5, so nothing fires (verified by absence of panic).
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
20
,
15
)
}
func
TestCheckBalanceAfterDeduction_NoCrossingNotFired
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"10"
u
:=
&
User
{
ID
:
1
,
BalanceNotifyEnabled
:
true
}
// 100 -> 95, both remain above threshold=10, no crossing.
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
100
,
5
)
// 5 -> 3, both already below threshold, no crossing (only fires on first
// cross from above-to-below).
s
.
CheckBalanceAfterDeduction
(
context
.
Background
(),
u
,
5
,
2
)
}
// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
func
TestCheckAccountQuotaAfterIncrement_NilAccount
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
// Should not panic.
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
nil
,
10
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_ZeroCost
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
0
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_NegativeCost
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
-
5
,
nil
)
}
func
TestCheckAccountQuotaAfterIncrement_GlobalDisabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"false"
a
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_daily_limit"
:
1000.0
,
"quota_daily_used"
:
950.0
,
},
}
// Global disabled → no processing even if a dim would cross.
s
.
CheckAccountQuotaAfterIncrement
(
context
.
Background
(),
a
,
100
,
nil
)
}
// ---------- sanity: internal helpers still work ----------
func
TestGetBalanceNotifyConfig_AllFields
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"12.5"
repo
.
data
[
SettingKeyBalanceLowNotifyRechargeURL
]
=
"https://example.com/pay"
enabled
,
threshold
,
url
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
True
(
t
,
enabled
)
require
.
Equal
(
t
,
12.5
,
threshold
)
require
.
Equal
(
t
,
"https://example.com/pay"
,
url
)
}
func
TestGetBalanceNotifyConfig_Disabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"false"
enabled
,
_
,
_
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
False
(
t
,
enabled
)
}
func
TestGetBalanceNotifyConfig_InvalidThreshold
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeyBalanceLowNotifyEnabled
]
=
"true"
repo
.
data
[
SettingKeyBalanceLowNotifyThreshold
]
=
"not-a-number"
enabled
,
threshold
,
_
:=
s
.
getBalanceNotifyConfig
(
context
.
Background
())
require
.
True
(
t
,
enabled
)
require
.
Equal
(
t
,
0.0
,
threshold
)
}
func
TestIsAccountQuotaNotifyEnabled
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
// Missing key → false
require
.
False
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
// Explicit "false"
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"false"
require
.
False
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
// Explicit "true"
repo
.
data
[
SettingKeyAccountQuotaNotifyEnabled
]
=
"true"
require
.
True
(
t
,
s
.
isAccountQuotaNotifyEnabled
(
context
.
Background
()))
}
func
TestGetSiteName_FallsBackToDefault
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
name
:=
s
.
getSiteName
(
context
.
Background
())
require
.
Equal
(
t
,
defaultSiteName
,
name
)
}
func
TestGetSiteName_Configured
(
t
*
testing
.
T
)
{
s
,
repo
:=
newBalanceNotifyServiceForTest
()
repo
.
data
[
SettingKeySiteName
]
=
"My Site"
require
.
Equal
(
t
,
"My Site"
,
s
.
getSiteName
(
context
.
Background
()))
}
// ---------- crossedDownward ----------
func
TestCrossedDownward_CrossesBelow
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance < threshold → true
require
.
True
(
t
,
crossedDownward
(
100
,
5
,
10
))
}
func
TestCrossedDownward_ExactlyAtThreshold
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance == threshold → false (not below)
require
.
False
(
t
,
crossedDownward
(
100
,
10
,
10
))
}
func
TestCrossedDownward_OldExactlyAtThreshold_NewBelow
(
t
*
testing
.
T
)
{
// oldBalance == threshold, newBalance < threshold → true
// (at-or-above → below counts as a crossing)
require
.
True
(
t
,
crossedDownward
(
10
,
5
,
10
))
}
func
TestCrossedDownward_AlreadyBelow
(
t
*
testing
.
T
)
{
// oldBalance < threshold → false (already below, no new crossing)
require
.
False
(
t
,
crossedDownward
(
5
,
3
,
10
))
}
func
TestCrossedDownward_BothAbove
(
t
*
testing
.
T
)
{
// oldBalance > threshold, newBalance > threshold → false (no crossing)
require
.
False
(
t
,
crossedDownward
(
100
,
50
,
10
))
}
func
TestCrossedDownward_ZeroThreshold
(
t
*
testing
.
T
)
{
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
// Typical case: positive balances should not fire when threshold is 0.
require
.
False
(
t
,
crossedDownward
(
10
,
5
,
0
))
require
.
False
(
t
,
crossedDownward
(
0
,
0
,
0
))
}
func
TestCrossedDownward_ZeroThreshold_NegativeNew
(
t
*
testing
.
T
)
{
// Edge case: newBalance goes negative with threshold=0.
require
.
True
(
t
,
crossedDownward
(
5
,
-
1
,
0
))
}
func
TestCrossedDownward_NegativeValues
(
t
*
testing
.
T
)
{
// Both already negative, threshold is positive → no crossing (already below).
require
.
False
(
t
,
crossedDownward
(
-
5
,
-
10
,
10
))
}
func
TestCrossedDownward_LargeDecrement
(
t
*
testing
.
T
)
{
// A single large deduction crosses the threshold.
require
.
True
(
t
,
crossedDownward
(
1000
,
0.5
,
100
))
}
func
TestCrossedDownward_SmallDecrement_NoCrossing
(
t
*
testing
.
T
)
{
// A tiny deduction stays above threshold.
require
.
False
(
t
,
crossedDownward
(
100
,
99.99
,
10
))
}
// ---------- checkQuotaDimCrossings ----------
func
TestCheckQuotaDimCrossings_NoDimensions
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// Empty dims → no crossing, no panic.
s
.
checkQuotaDimCrossings
(
account
,
nil
,
10
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
s
.
checkQuotaDimCrossings
(
account
,
[]
quotaDim
{},
10
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_DisabledDimension
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
false
,
// disabled
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
// Disabled dimension should be skipped even if crossing would occur.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_ZeroThresholdSkipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
0
,
// zero threshold
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
// Zero threshold → skipped.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
300
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
800
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
// Negative resolved threshold → skipped.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
1200
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
950
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimWeekly
,
enabled
:
true
,
threshold
:
30
,
thresholdType
:
thresholdTypePercentage
,
currentUsed
:
500
,
limit
:
1000
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_ZeroLimit_Skipped
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// limit=0 → resolvedThreshold returns 0 → skipped.
dims
:=
[]
quotaDim
{
{
name
:
quotaDimTotal
,
enabled
:
true
,
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
50
,
limit
:
0
,
},
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
func
TestCheckQuotaDimCrossings_MultipleDims_MixedResults
(
t
*
testing
.
T
)
{
s
,
_
:=
newBalanceNotifyServiceForTest
()
account
:=
&
Account
{
ID
:
1
,
Name
:
"test"
,
Platform
:
PlatformAnthropic
}
// dim1: no crossing (both below effective threshold)
// dim2: disabled (skipped)
// dim3: zero threshold (skipped)
dims
:=
[]
quotaDim
{
{
name
:
quotaDimDaily
,
enabled
:
true
,
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
300
,
// oldUsed=250, effectiveThreshold=600, both below
limit
:
1000
,
},
{
name
:
quotaDimWeekly
,
enabled
:
false
,
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
900
,
limit
:
1000
,
},
{
name
:
quotaDimTotal
,
enabled
:
true
,
threshold
:
0
,
thresholdType
:
thresholdTypeFixed
,
currentUsed
:
500
,
limit
:
1000
,
},
}
// None should trigger. No panic expected.
s
.
checkQuotaDimCrossings
(
account
,
dims
,
50
,
[]
string
{
"admin@example.com"
},
"TestSite"
)
}
backend/internal/service/balance_notify_email_body_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// These tests guard against fmt.Sprintf arg-count mismatches in the email
// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
// the output, which these assertions will catch.
// ---------- buildBalanceLowEmailBody ----------
func
TestBuildBalanceLowEmailBody_ContainsRequiredFields
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"Alice"
,
3.14
,
10.0
,
"MySite"
,
""
)
// All substituted values should appear in the output.
require
.
Contains
(
t
,
body
,
"MySite"
)
require
.
Contains
(
t
,
body
,
"Alice"
)
require
.
Contains
(
t
,
body
,
"$3.14"
)
require
.
Contains
(
t
,
body
,
"$10.00"
)
// No fmt.Sprintf format error markers.
require
.
NotContains
(
t
,
body
,
"%!"
)
require
.
NotContains
(
t
,
body
,
"MISSING"
)
require
.
NotContains
(
t
,
body
,
"EXTRA"
)
}
func
TestBuildBalanceLowEmailBody_WithRechargeURL
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"Bob"
,
5.0
,
20.0
,
"Site"
,
"https://example.com/pay"
)
// The recharge anchor element should appear with the URL.
require
.
Contains
(
t
,
body
,
`href="https://example.com/pay"`
)
require
.
Contains
(
t
,
body
,
"立即充值"
)
require
.
NotContains
(
t
,
body
,
"%!"
)
}
func
TestBuildBalanceLowEmailBody_RechargeURLEscaped
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
// Try a URL with characters that need HTML escaping.
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
`https://example.com/?a=1&b=<script>`
)
// `&` and `<` should be escaped in the href.
require
.
Contains
(
t
,
body
,
"&"
)
require
.
Contains
(
t
,
body
,
"<script>"
)
require
.
NotContains
(
t
,
body
,
"<script>"
)
}
func
TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
""
)
// The anchor element should not be rendered (style class may still appear).
require
.
NotContains
(
t
,
body
,
`<a href`
)
require
.
NotContains
(
t
,
body
,
"立即充值"
)
}
// ---------- buildQuotaAlertEmailBody ----------
func
TestBuildQuotaAlertEmailBody_AllFieldsPresent
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
42
,
// accountID
"acc-foo"
,
// accountName
"anthropic"
,
// platform
"日限额 / Daily"
,
// dimLabel
750.50
,
// used
1000.0
,
// limit
249.50
,
// remaining
"$249.50"
,
// thresholdDisplay
"MySite"
,
// siteName
)
require
.
Contains
(
t
,
body
,
"MySite"
)
require
.
Contains
(
t
,
body
,
"#42"
)
require
.
Contains
(
t
,
body
,
"acc-foo"
)
require
.
Contains
(
t
,
body
,
"anthropic"
)
require
.
Contains
(
t
,
body
,
"Daily"
)
require
.
Contains
(
t
,
body
,
"$750.50"
)
require
.
Contains
(
t
,
body
,
"$1000.00"
)
require
.
Contains
(
t
,
body
,
"$249.50"
)
// No format error markers.
require
.
NotContains
(
t
,
body
,
"%!"
)
require
.
NotContains
(
t
,
body
,
"MISSING"
)
require
.
NotContains
(
t
,
body
,
"EXTRA"
)
}
func
TestBuildQuotaAlertEmailBody_UnlimitedDisplay
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
100.0
,
0.0
,
// limit=0 triggers unlimited branch
0.0
,
"30%"
,
"Site"
,
)
require
.
Contains
(
t
,
body
,
"无限制"
)
require
.
Contains
(
t
,
body
,
"Unlimited"
)
}
func
TestBuildQuotaAlertEmailBody_PercentageThresholdDisplay
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
700.0
,
1000.0
,
300.0
,
"30%"
,
// percentage-formatted threshold
"Site"
,
)
require
.
Contains
(
t
,
body
,
"30%"
)
require
.
NotContains
(
t
,
body
,
"%!"
)
}
func
TestBuildQuotaAlertEmailBody_RemainingClampedAtZero
(
t
*
testing
.
T
)
{
// Even though caller is responsible for clamping, this test documents the
// display behavior with remaining=0.
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"dim"
,
1500.0
,
1000.0
,
0.0
,
// used > limit (over-quota)
"$100.00"
,
"Site"
,
)
require
.
Contains
(
t
,
body
,
"$0.00"
)
}
// ---------- sanity checks on the CSS `%%` escape ----------
func
TestBuildBalanceLowEmailBody_NoCSSFormatError
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildBalanceLowEmailBody
(
"u"
,
1.0
,
5.0
,
"Site"
,
""
)
// CSS `linear-gradient(135deg, #f59e0b 0%, #d97706 100%)` should appear with
// literal percent signs (from the %% escape in the template).
require
.
True
(
t
,
strings
.
Contains
(
body
,
"0%"
)
&&
strings
.
Contains
(
body
,
"100%"
),
"CSS gradient percentages not rendered; got: %s"
,
body
)
}
func
TestBuildQuotaAlertEmailBody_NoCSSFormatError
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
body
:=
s
.
buildQuotaAlertEmailBody
(
1
,
"n"
,
"p"
,
"d"
,
0
,
0
,
0
,
"$0.00"
,
"Site"
)
require
.
True
(
t
,
strings
.
Contains
(
body
,
"0%"
)
&&
strings
.
Contains
(
body
,
"100%"
),
"CSS gradient percentages not rendered; got: %s"
,
body
)
}
backend/internal/service/balance_notify_service.go
0 → 100644
View file @
0b746501
package
service
import
(
"context"
"fmt"
"html"
"log/slog"
"strconv"
"strings"
"time"
)
const
(
emailSendTimeout
=
30
*
time
.
Second
// Threshold type values
thresholdTypeFixed
=
"fixed"
thresholdTypePercentage
=
"percentage"
// Quota dimension labels
quotaDimDaily
=
"daily"
quotaDimWeekly
=
"weekly"
quotaDimTotal
=
"total"
defaultSiteName
=
"Sub2API"
)
// quotaDimLabels maps dimension names to display labels.
var
quotaDimLabels
=
map
[
string
]
string
{
quotaDimDaily
:
"日限额 / Daily"
,
quotaDimWeekly
:
"周限额 / Weekly"
,
quotaDimTotal
:
"总限额 / Total"
,
}
// AccountQuotaReader provides read access to account quota data.
type
AccountQuotaReader
interface
{
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type
BalanceNotifyService
struct
{
emailService
*
EmailService
settingRepo
SettingRepository
accountRepo
AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func
NewBalanceNotifyService
(
emailService
*
EmailService
,
settingRepo
SettingRepository
,
accountRepo
AccountQuotaReader
)
*
BalanceNotifyService
{
return
&
BalanceNotifyService
{
emailService
:
emailService
,
settingRepo
:
settingRepo
,
accountRepo
:
accountRepo
,
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func
resolveBalanceThreshold
(
threshold
float64
,
thresholdType
string
,
totalRecharged
float64
)
float64
{
if
thresholdType
==
thresholdTypePercentage
&&
totalRecharged
>
0
{
return
totalRecharged
*
threshold
/
100
}
return
threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
func
(
s
*
BalanceNotifyService
)
CheckBalanceAfterDeduction
(
ctx
context
.
Context
,
user
*
User
,
oldBalance
,
cost
float64
)
{
if
!
s
.
canNotifyBalance
(
user
)
{
return
}
effectiveThreshold
,
rechargeURL
,
ok
:=
s
.
resolveUserEffectiveThreshold
(
ctx
,
user
)
if
!
ok
{
return
}
newBalance
:=
oldBalance
-
cost
if
!
crossedDownward
(
oldBalance
,
newBalance
,
effectiveThreshold
)
{
return
}
s
.
dispatchBalanceLowEmail
(
ctx
,
user
,
newBalance
,
effectiveThreshold
,
rechargeURL
)
}
// canNotifyBalance checks nil guards and user-level toggle.
func
(
s
*
BalanceNotifyService
)
canNotifyBalance
(
user
*
User
)
bool
{
if
user
==
nil
||
s
.
emailService
==
nil
||
s
.
settingRepo
==
nil
{
return
false
}
return
user
.
BalanceNotifyEnabled
}
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
// Returns ok=false when notifications should be skipped.
func
(
s
*
BalanceNotifyService
)
resolveUserEffectiveThreshold
(
ctx
context
.
Context
,
user
*
User
)
(
effectiveThreshold
float64
,
rechargeURL
string
,
ok
bool
)
{
globalEnabled
,
globalThreshold
,
rechargeURL
:=
s
.
getBalanceNotifyConfig
(
ctx
)
if
!
globalEnabled
{
return
0
,
""
,
false
}
threshold
:=
globalThreshold
if
user
.
BalanceNotifyThreshold
!=
nil
{
threshold
=
*
user
.
BalanceNotifyThreshold
}
if
threshold
<=
0
{
return
0
,
""
,
false
}
effectiveThreshold
=
resolveBalanceThreshold
(
threshold
,
user
.
BalanceNotifyThresholdType
,
user
.
TotalRecharged
)
if
effectiveThreshold
<=
0
{
return
0
,
""
,
false
}
return
effectiveThreshold
,
rechargeURL
,
true
}
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
func
crossedDownward
(
oldV
,
newV
,
threshold
float64
)
bool
{
return
oldV
>=
threshold
&&
newV
<
threshold
}
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
func
(
s
*
BalanceNotifyService
)
dispatchBalanceLowEmail
(
ctx
context
.
Context
,
user
*
User
,
newBalance
,
threshold
float64
,
rechargeURL
string
)
{
siteName
:=
s
.
getSiteName
(
ctx
)
recipients
:=
s
.
collectBalanceNotifyRecipients
(
user
)
slog
.
Info
(
"CheckBalanceAfterDeduction: sending notification"
,
"user_id"
,
user
.
ID
,
"recipients"
,
recipients
,
"new_balance"
,
newBalance
,
"threshold"
,
threshold
)
go
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in balance notification"
,
"recover"
,
r
)
}
}()
s
.
sendBalanceLowEmails
(
recipients
,
user
.
Username
,
user
.
Email
,
newBalance
,
threshold
,
siteName
,
rechargeURL
)
}()
}
// quotaDim describes one quota dimension for notification checking.
type
quotaDim
struct
{
name
string
enabled
bool
threshold
float64
thresholdType
string
// "fixed" (default) or "percentage"
currentUsed
float64
limit
float64
}
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
// The threshold represents how much quota REMAINS when the alert fires:
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
func
(
d
quotaDim
)
resolvedThreshold
()
float64
{
if
d
.
limit
<=
0
{
return
0
}
if
d
.
thresholdType
==
thresholdTypePercentage
{
return
d
.
limit
*
(
1
-
d
.
threshold
/
100
)
}
return
d
.
limit
-
d
.
threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func
buildQuotaDims
(
account
*
Account
)
[]
quotaDim
{
return
[]
quotaDim
{
{
quotaDimDaily
,
account
.
GetQuotaNotifyDailyEnabled
(),
account
.
GetQuotaNotifyDailyThreshold
(),
account
.
GetQuotaNotifyDailyThresholdType
(),
account
.
GetQuotaDailyUsed
(),
account
.
GetQuotaDailyLimit
()},
{
quotaDimWeekly
,
account
.
GetQuotaNotifyWeeklyEnabled
(),
account
.
GetQuotaNotifyWeeklyThreshold
(),
account
.
GetQuotaNotifyWeeklyThresholdType
(),
account
.
GetQuotaWeeklyUsed
(),
account
.
GetQuotaWeeklyLimit
()},
{
quotaDimTotal
,
account
.
GetQuotaNotifyTotalEnabled
(),
account
.
GetQuotaNotifyTotalThreshold
(),
account
.
GetQuotaNotifyTotalThresholdType
(),
account
.
GetQuotaUsed
(),
account
.
GetQuotaLimit
()},
}
}
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
func
buildQuotaDimsFromState
(
account
*
Account
,
state
*
AccountQuotaState
)
[]
quotaDim
{
return
[]
quotaDim
{
{
quotaDimDaily
,
account
.
GetQuotaNotifyDailyEnabled
(),
account
.
GetQuotaNotifyDailyThreshold
(),
account
.
GetQuotaNotifyDailyThresholdType
(),
state
.
DailyUsed
,
state
.
DailyLimit
},
{
quotaDimWeekly
,
account
.
GetQuotaNotifyWeeklyEnabled
(),
account
.
GetQuotaNotifyWeeklyThreshold
(),
account
.
GetQuotaNotifyWeeklyThresholdType
(),
state
.
WeeklyUsed
,
state
.
WeeklyLimit
},
{
quotaDimTotal
,
account
.
GetQuotaNotifyTotalEnabled
(),
account
.
GetQuotaNotifyTotalThreshold
(),
account
.
GetQuotaNotifyTotalThresholdType
(),
state
.
TotalUsed
,
state
.
TotalLimit
},
}
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
func
(
s
*
BalanceNotifyService
)
CheckAccountQuotaAfterIncrement
(
ctx
context
.
Context
,
account
*
Account
,
cost
float64
,
quotaState
*
AccountQuotaState
)
{
if
account
==
nil
||
s
.
emailService
==
nil
||
s
.
settingRepo
==
nil
||
cost
<=
0
{
return
}
if
!
s
.
isAccountQuotaNotifyEnabled
(
ctx
)
{
return
}
adminEmails
:=
s
.
getAccountQuotaNotifyEmails
(
ctx
)
if
len
(
adminEmails
)
==
0
{
return
}
siteName
:=
s
.
getSiteName
(
ctx
)
var
dims
[]
quotaDim
if
quotaState
!=
nil
{
dims
=
buildQuotaDimsFromState
(
account
,
quotaState
)
}
else
{
freshAccount
:=
s
.
fetchFreshAccount
(
ctx
,
account
)
dims
=
buildQuotaDims
(
freshAccount
)
account
=
freshAccount
// use fresh data for alert metadata
}
s
.
checkQuotaDimCrossings
(
account
,
dims
,
cost
,
adminEmails
,
siteName
)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func
(
s
*
BalanceNotifyService
)
fetchFreshAccount
(
ctx
context
.
Context
,
snapshot
*
Account
)
*
Account
{
if
s
.
accountRepo
==
nil
{
return
snapshot
}
fresh
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
snapshot
.
ID
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to fetch fresh account for quota notify, using snapshot"
,
"account_id"
,
snapshot
.
ID
,
"error"
,
err
)
return
snapshot
}
return
fresh
}
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
func
(
s
*
BalanceNotifyService
)
checkQuotaDimCrossings
(
account
*
Account
,
dims
[]
quotaDim
,
cost
float64
,
adminEmails
[]
string
,
siteName
string
)
{
for
_
,
dim
:=
range
dims
{
if
!
dim
.
enabled
||
dim
.
threshold
<=
0
{
continue
}
effectiveThreshold
:=
dim
.
resolvedThreshold
()
if
effectiveThreshold
<=
0
{
continue
}
newUsed
:=
dim
.
currentUsed
oldUsed
:=
dim
.
currentUsed
-
cost
if
oldUsed
<
effectiveThreshold
&&
newUsed
>=
effectiveThreshold
{
s
.
asyncSendQuotaAlert
(
adminEmails
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
dim
,
newUsed
,
effectiveThreshold
,
siteName
)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
func
(
s
*
BalanceNotifyService
)
asyncSendQuotaAlert
(
adminEmails
[]
string
,
accountID
int64
,
accountName
,
platform
string
,
dim
quotaDim
,
newUsed
,
effectiveThreshold
float64
,
siteName
string
)
{
go
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in quota notification"
,
"recover"
,
r
)
}
}()
s
.
sendQuotaAlertEmails
(
adminEmails
,
accountID
,
accountName
,
platform
,
dim
,
newUsed
,
siteName
)
}()
}
// getBalanceNotifyConfig reads global balance notification settings.
func
(
s
*
BalanceNotifyService
)
getBalanceNotifyConfig
(
ctx
context
.
Context
)
(
enabled
bool
,
threshold
float64
,
rechargeURL
string
)
{
keys
:=
[]
string
{
SettingKeyBalanceLowNotifyEnabled
,
SettingKeyBalanceLowNotifyThreshold
,
SettingKeyBalanceLowNotifyRechargeURL
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
false
,
0
,
""
}
enabled
=
settings
[
SettingKeyBalanceLowNotifyEnabled
]
==
"true"
if
v
:=
settings
[
SettingKeyBalanceLowNotifyThreshold
];
v
!=
""
{
if
f
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
threshold
=
f
}
}
rechargeURL
=
settings
[
SettingKeyBalanceLowNotifyRechargeURL
]
return
}
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
func
(
s
*
BalanceNotifyService
)
isAccountQuotaNotifyEnabled
(
ctx
context
.
Context
)
bool
{
val
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAccountQuotaNotifyEnabled
)
if
err
!=
nil
{
return
false
}
return
val
==
"true"
}
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
// filtering out disabled and unverified entries.
func
(
s
*
BalanceNotifyService
)
getAccountQuotaNotifyEmails
(
ctx
context
.
Context
)
[]
string
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAccountQuotaNotifyEmails
)
if
err
!=
nil
||
strings
.
TrimSpace
(
raw
)
==
""
||
raw
==
"[]"
{
return
nil
}
entries
:=
ParseNotifyEmails
(
raw
)
if
len
(
entries
)
==
0
{
return
nil
}
return
filterVerifiedEmails
(
entries
)
}
// getSiteName reads site name from settings with fallback.
func
(
s
*
BalanceNotifyService
)
getSiteName
(
ctx
context
.
Context
)
string
{
name
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
)
if
err
!=
nil
||
name
==
""
{
return
defaultSiteName
}
return
name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func
filterVerifiedEmails
(
entries
[]
NotifyEmailEntry
)
[]
string
{
var
recipients
[]
string
seen
:=
make
(
map
[
string
]
bool
)
for
_
,
entry
:=
range
entries
{
if
entry
.
Disabled
||
!
entry
.
Verified
{
continue
}
email
:=
strings
.
TrimSpace
(
entry
.
Email
)
if
email
==
""
{
continue
}
lower
:=
strings
.
ToLower
(
email
)
if
seen
[
lower
]
{
continue
}
seen
[
lower
]
=
true
recipients
=
append
(
recipients
,
email
)
}
return
recipients
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func
(
s
*
BalanceNotifyService
)
collectBalanceNotifyRecipients
(
user
*
User
)
[]
string
{
return
filterVerifiedEmails
(
user
.
BalanceNotifyExtraEmails
)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.
func
(
s
*
BalanceNotifyService
)
sendEmails
(
recipients
[]
string
,
subject
,
body
string
,
logAttrs
...
any
)
{
if
len
(
recipients
)
==
0
{
slog
.
Warn
(
"sendEmails: no recipients"
,
"subject"
,
subject
)
return
}
for
_
,
to
:=
range
recipients
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
emailSendTimeout
)
if
err
:=
s
.
emailService
.
SendEmail
(
ctx
,
to
,
subject
,
body
);
err
!=
nil
{
attrs
:=
append
([]
any
{
"to"
,
to
,
"error"
,
err
},
logAttrs
...
)
slog
.
Error
(
"failed to send notification"
,
attrs
...
)
}
else
{
slog
.
Info
(
"notification email sent successfully"
,
"to"
,
to
,
"subject"
,
subject
)
}
cancel
()
}
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func
(
s
*
BalanceNotifyService
)
sendBalanceLowEmails
(
recipients
[]
string
,
userName
,
userEmail
string
,
balance
,
threshold
float64
,
siteName
,
rechargeURL
string
)
{
displayName
:=
userName
if
displayName
==
""
{
displayName
=
userEmail
}
subject
:=
fmt
.
Sprintf
(
"[%s] 余额不足提醒 / Balance Low Alert"
,
sanitizeEmailHeader
(
siteName
))
body
:=
s
.
buildBalanceLowEmailBody
(
html
.
EscapeString
(
displayName
),
balance
,
threshold
,
html
.
EscapeString
(
siteName
),
rechargeURL
)
s
.
sendEmails
(
recipients
,
subject
,
body
,
"user_email"
,
userEmail
,
"balance"
,
balance
)
}
// sendQuotaAlertEmails sends quota alert notification to admin emails.
func
(
s
*
BalanceNotifyService
)
sendQuotaAlertEmails
(
adminEmails
[]
string
,
accountID
int64
,
accountName
,
platform
string
,
dim
quotaDim
,
used
float64
,
siteName
string
)
{
dimLabel
:=
quotaDimLabels
[
dim
.
name
]
if
dimLabel
==
""
{
dimLabel
=
dim
.
name
}
// Format the remaining-based threshold for display
thresholdDisplay
:=
fmt
.
Sprintf
(
"$%.2f"
,
dim
.
threshold
)
if
dim
.
thresholdType
==
thresholdTypePercentage
{
thresholdDisplay
=
fmt
.
Sprintf
(
"%.0f%%"
,
dim
.
threshold
)
}
remaining
:=
dim
.
limit
-
used
if
remaining
<
0
{
remaining
=
0
}
subject
:=
fmt
.
Sprintf
(
"[%s] 账号限额告警 / Account Quota Alert - %s"
,
sanitizeEmailHeader
(
siteName
),
sanitizeEmailHeader
(
accountName
))
body
:=
s
.
buildQuotaAlertEmailBody
(
accountID
,
html
.
EscapeString
(
accountName
),
html
.
EscapeString
(
platform
),
html
.
EscapeString
(
dimLabel
),
used
,
dim
.
limit
,
remaining
,
thresholdDisplay
,
html
.
EscapeString
(
siteName
))
s
.
sendEmails
(
adminEmails
,
subject
,
body
,
"account"
,
accountName
,
"dimension"
,
dim
.
name
)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func
sanitizeEmailHeader
(
s
string
)
string
{
return
strings
.
NewReplacer
(
"
\r
"
,
""
,
"
\n
"
,
""
)
.
Replace
(
s
)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
// The recharge button is appended dynamically when rechargeURL is set.
const
balanceLowEmailTemplate
=
`<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333;">%s,您的余额不足</p>
<p style="color: #666;">Dear %s, your balance is running low</p>
<div class="balance">$%.2f</div>
<div class="info">
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong>。</p>
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
<p>请及时充值以免服务中断。</p>
<p>Please top up to avoid service interruption.</p>
</div>
%s
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
const
quotaAlertEmailTemplate
=
`<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; }
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
.metric-label { color: #666; }
.metric-value { font-weight: bold; color: #333; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
<div class="info">
<p>账号剩余额度已低于提醒阈值,请及时关注。</p>
<p>Account remaining quota has fallen below the alert threshold.</p>
</div>
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func
(
s
*
BalanceNotifyService
)
buildBalanceLowEmailBody
(
userName
string
,
balance
,
threshold
float64
,
siteName
,
rechargeURL
string
)
string
{
rechargeBlock
:=
""
if
rechargeURL
!=
""
{
rechargeBlock
=
fmt
.
Sprintf
(
`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`
,
html
.
EscapeString
(
rechargeURL
))
}
return
fmt
.
Sprintf
(
balanceLowEmailTemplate
,
siteName
,
userName
,
userName
,
balance
,
threshold
,
threshold
,
rechargeBlock
)
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func
(
s
*
BalanceNotifyService
)
buildQuotaAlertEmailBody
(
accountID
int64
,
accountName
,
platform
,
dimLabel
string
,
used
,
limit
,
remaining
float64
,
thresholdDisplay
,
siteName
string
)
string
{
limitStr
:=
fmt
.
Sprintf
(
"$%.2f"
,
limit
)
if
limit
<=
0
{
limitStr
=
"无限制 / Unlimited"
}
return
fmt
.
Sprintf
(
quotaAlertEmailTemplate
,
siteName
,
accountID
,
accountName
,
platform
,
dimLabel
,
used
,
limitStr
,
remaining
,
thresholdDisplay
)
}
backend/internal/service/balance_notify_service_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// ---------- resolveBalanceThreshold ----------
func
TestResolveBalanceThreshold_Fixed
(
t
*
testing
.
T
)
{
// Fixed type always returns the raw threshold regardless of totalRecharged.
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypeFixed
,
1000
))
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypeFixed
,
0
))
require
.
Equal
(
t
,
0.0
,
resolveBalanceThreshold
(
0
,
thresholdTypeFixed
,
1000
))
}
func
TestResolveBalanceThreshold_Percentage
(
t
*
testing
.
T
)
{
// 10% of 1000 = 100
require
.
Equal
(
t
,
100.0
,
resolveBalanceThreshold
(
10
,
thresholdTypePercentage
,
1000
))
// 50% of 200 = 100
require
.
Equal
(
t
,
100.0
,
resolveBalanceThreshold
(
50
,
thresholdTypePercentage
,
200
))
}
func
TestResolveBalanceThreshold_PercentageZeroRecharged
(
t
*
testing
.
T
)
{
// When totalRecharged is 0, percentage falls through to raw threshold
// (treated as fixed). This is the defensive behavior.
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
thresholdTypePercentage
,
0
))
}
func
TestResolveBalanceThreshold_EmptyType
(
t
*
testing
.
T
)
{
// Empty type is treated as fixed (not percentage).
require
.
Equal
(
t
,
10.0
,
resolveBalanceThreshold
(
10
,
""
,
1000
))
}
// ---------- quotaDim.resolvedThreshold ----------
func
TestResolvedThreshold_FixedNormal
(
t
*
testing
.
T
)
{
// threshold=400 remaining, limit=1000 → usage trigger at 600
d
:=
quotaDim
{
threshold
:
400
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
600.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_FixedThresholdExceedsLimit
(
t
*
testing
.
T
)
{
// threshold=1200, limit=1000 → returns negative, callers must skip
d
:=
quotaDim
{
threshold
:
1200
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
-
200.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_FixedThresholdEqualsLimit
(
t
*
testing
.
T
)
{
// threshold=1000, limit=1000 → returns 0 (alert fires at 0 usage)
d
:=
quotaDim
{
threshold
:
1000
,
thresholdType
:
thresholdTypeFixed
,
limit
:
1000
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_PercentageNormal
(
t
*
testing
.
T
)
{
// threshold=30%, limit=1000 → usage trigger at 700 (remaining drops to 30%)
d
:=
quotaDim
{
threshold
:
30
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
700.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageZeroPercent
(
t
*
testing
.
T
)
{
// threshold=0%, limit=1000 → fires when remaining drops to 0 (usage=1000)
d
:=
quotaDim
{
threshold
:
0
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
1000.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageHundredPercent
(
t
*
testing
.
T
)
{
// threshold=100%, limit=1000 → fires immediately (remaining drops to 100% i.e. nothing used yet)
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
InDelta
(
t
,
0.0
,
d
.
resolvedThreshold
(),
0.001
)
}
func
TestResolvedThreshold_PercentageOverHundred
(
t
*
testing
.
T
)
{
// threshold=150%, limit=1000 → returns negative (never triggers; callers skip)
d
:=
quotaDim
{
threshold
:
150
,
thresholdType
:
thresholdTypePercentage
,
limit
:
1000
}
require
.
Less
(
t
,
d
.
resolvedThreshold
(),
0.0
)
}
func
TestResolvedThreshold_ZeroLimit
(
t
*
testing
.
T
)
{
// limit=0 → returns 0 to avoid division and false alerts on unlimited quotas
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
limit
:
0
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
func
TestResolvedThreshold_NegativeLimit
(
t
*
testing
.
T
)
{
// Negative limit treated as 0
d
:=
quotaDim
{
threshold
:
100
,
thresholdType
:
thresholdTypeFixed
,
limit
:
-
10
}
require
.
Equal
(
t
,
0.0
,
d
.
resolvedThreshold
())
}
// ---------- sanitizeEmailHeader ----------
func
TestSanitizeEmailHeader_CRLF
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"Subject injected"
,
sanitizeEmailHeader
(
"Subject
\r\n
injected"
))
}
func
TestSanitizeEmailHeader_OnlyCR
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"foobar"
,
sanitizeEmailHeader
(
"foo
\r
bar"
))
}
func
TestSanitizeEmailHeader_OnlyLF
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"foobar"
,
sanitizeEmailHeader
(
"foo
\n
bar"
))
}
func
TestSanitizeEmailHeader_Clean
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"Sub2API"
,
sanitizeEmailHeader
(
"Sub2API"
))
}
func
TestSanitizeEmailHeader_Empty
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
sanitizeEmailHeader
(
""
))
}
func
TestSanitizeEmailHeader_MultipleNewlines
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"abc"
,
sanitizeEmailHeader
(
"a
\r\n
b
\r\n
c"
))
}
// ---------- buildQuotaDims ----------
func
TestBuildQuotaDims_AllDimensionsReturned
(
t
*
testing
.
T
)
{
// Use an account with quota notify config across all 3 dimensions.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_notify_daily_threshold_type"
:
thresholdTypeFixed
,
"quota_notify_weekly_enabled"
:
true
,
"quota_notify_weekly_threshold"
:
20.0
,
"quota_notify_weekly_threshold_type"
:
thresholdTypePercentage
,
"quota_notify_total_enabled"
:
false
,
"quota_daily_limit"
:
500.0
,
"quota_weekly_limit"
:
2000.0
,
"quota_limit"
:
10000.0
,
"quota_daily_used"
:
50.0
,
"quota_weekly_used"
:
300.0
,
"quota_used"
:
1000.0
,
},
}
dims
:=
buildQuotaDims
(
a
)
require
.
Len
(
t
,
dims
,
3
)
// Daily
require
.
Equal
(
t
,
quotaDimDaily
,
dims
[
0
]
.
name
)
require
.
True
(
t
,
dims
[
0
]
.
enabled
)
require
.
Equal
(
t
,
100.0
,
dims
[
0
]
.
threshold
)
require
.
Equal
(
t
,
thresholdTypeFixed
,
dims
[
0
]
.
thresholdType
)
require
.
Equal
(
t
,
500.0
,
dims
[
0
]
.
limit
)
require
.
Equal
(
t
,
50.0
,
dims
[
0
]
.
currentUsed
)
// Weekly
require
.
Equal
(
t
,
quotaDimWeekly
,
dims
[
1
]
.
name
)
require
.
True
(
t
,
dims
[
1
]
.
enabled
)
require
.
Equal
(
t
,
20.0
,
dims
[
1
]
.
threshold
)
require
.
Equal
(
t
,
thresholdTypePercentage
,
dims
[
1
]
.
thresholdType
)
require
.
Equal
(
t
,
2000.0
,
dims
[
1
]
.
limit
)
// Total
require
.
Equal
(
t
,
quotaDimTotal
,
dims
[
2
]
.
name
)
require
.
False
(
t
,
dims
[
2
]
.
enabled
)
require
.
Equal
(
t
,
10000.0
,
dims
[
2
]
.
limit
)
require
.
Equal
(
t
,
1000.0
,
dims
[
2
]
.
currentUsed
)
}
func
TestBuildQuotaDims_EmptyExtra
(
t
*
testing
.
T
)
{
// Missing fields default to zero/disabled.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{},
}
dims
:=
buildQuotaDims
(
a
)
require
.
Len
(
t
,
dims
,
3
)
for
_
,
d
:=
range
dims
{
require
.
False
(
t
,
d
.
enabled
)
require
.
Equal
(
t
,
0.0
,
d
.
threshold
)
require
.
Equal
(
t
,
0.0
,
d
.
limit
)
}
}
// ---------- buildQuotaDimsFromState ----------
func
TestBuildQuotaDimsFromState_UsesStateValues
(
t
*
testing
.
T
)
{
// Usage values should come from the state, not the account.
a
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"quota_notify_daily_enabled"
:
true
,
"quota_notify_daily_threshold"
:
100.0
,
"quota_daily_used"
:
999.0
,
// should be ignored
"quota_daily_limit"
:
999.0
,
// should be ignored
},
}
state
:=
&
AccountQuotaState
{
DailyUsed
:
77.0
,
DailyLimit
:
500.0
,
WeeklyUsed
:
88.0
,
WeeklyLimit
:
2000.0
,
TotalUsed
:
99.0
,
TotalLimit
:
10000.0
,
}
dims
:=
buildQuotaDimsFromState
(
a
,
state
)
require
.
Len
(
t
,
dims
,
3
)
// Settings from account (enabled, threshold, thresholdType)
require
.
True
(
t
,
dims
[
0
]
.
enabled
)
require
.
Equal
(
t
,
100.0
,
dims
[
0
]
.
threshold
)
// Usage from state
require
.
Equal
(
t
,
77.0
,
dims
[
0
]
.
currentUsed
)
require
.
Equal
(
t
,
500.0
,
dims
[
0
]
.
limit
)
require
.
Equal
(
t
,
88.0
,
dims
[
1
]
.
currentUsed
)
require
.
Equal
(
t
,
2000.0
,
dims
[
1
]
.
limit
)
require
.
Equal
(
t
,
99.0
,
dims
[
2
]
.
currentUsed
)
require
.
Equal
(
t
,
10000.0
,
dims
[
2
]
.
limit
)
}
// ---------- collectBalanceNotifyRecipients ----------
func
TestCollectBalanceNotifyRecipients_Empty
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
nil
}
require
.
Empty
(
t
,
s
.
collectBalanceNotifyRecipients
(
u
))
}
func
TestCollectBalanceNotifyRecipients_FiltersDisabledAndUnverified
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
"a@example.com"
,
Verified
:
true
,
Disabled
:
false
},
{
Email
:
"b@example.com"
,
Verified
:
true
,
Disabled
:
true
},
// disabled
{
Email
:
"c@example.com"
,
Verified
:
false
,
Disabled
:
false
},
// unverified
{
Email
:
"d@example.com"
,
Verified
:
true
,
Disabled
:
false
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"a@example.com"
,
"d@example.com"
},
got
)
}
func
TestCollectBalanceNotifyRecipients_DeduplicatesCaseInsensitive
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
"User@Example.com"
,
Verified
:
true
},
{
Email
:
"user@example.com"
,
Verified
:
true
},
{
Email
:
"USER@EXAMPLE.COM"
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Len
(
t
,
got
,
1
)
// The original casing of the first entry is preserved.
require
.
Equal
(
t
,
"User@Example.com"
,
got
[
0
])
}
func
TestCollectBalanceNotifyRecipients_SkipsEmpty
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
" "
,
Verified
:
true
},
{
Email
:
""
,
Verified
:
true
},
{
Email
:
"valid@example.com"
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"valid@example.com"
},
got
)
}
func
TestCollectBalanceNotifyRecipients_TrimsWhitespace
(
t
*
testing
.
T
)
{
s
:=
&
BalanceNotifyService
{}
u
:=
&
User
{
BalanceNotifyExtraEmails
:
[]
NotifyEmailEntry
{
{
Email
:
" trimmed@example.com "
,
Verified
:
true
},
},
}
got
:=
s
.
collectBalanceNotifyRecipients
(
u
)
require
.
Equal
(
t
,
[]
string
{
"trimmed@example.com"
},
got
)
}
backend/internal/service/billing_service_test.go
View file @
0b746501
...
@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
...
@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
ActualCost
,
1e-10
)
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
ActualCost
,
1e-10
)
}
}
func
TestIsModelSupported
(
t
*
testing
.
T
)
{
func
TestIsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
svc
:=
newTestBillingService
()
...
@@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
...
@@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
require
.
InDelta
(
t
,
1.5
,
pricing
.
LongContextInputMultiplier
,
1e-12
)
require
.
InDelta
(
t
,
1.5
,
pricing
.
LongContextInputMultiplier
,
1e-12
)
require
.
InDelta
(
t
,
1.25
,
pricing
.
LongContextOutputMultiplier
,
1e-12
)
require
.
InDelta
(
t
,
1.25
,
pricing
.
LongContextOutputMultiplier
,
1e-12
)
}
}
// ---------------------------------------------------------------------------
// GetModelPricingWithChannel
// ---------------------------------------------------------------------------
func
TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
pricing
)
// Should be identical to GetModelPricing
original
,
err
:=
svc
.
GetModelPricing
(
"claude-sonnet-4"
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
original
.
InputPricePerToken
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
OutputPricePerToken
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
CacheCreationPricePerToken
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
original
.
CacheReadPricePerToken
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideInputPriceOnly
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
99e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// InputPrice overridden (both normal and priority)
require
.
InDelta
(
t
,
99e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
99e-6
,
pricing
.
InputPricePerTokenPriority
,
1e-12
)
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
require
.
InDelta
(
t
,
15e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideOutputPriceOnly
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
OutputPrice
:
testPtrFloat64
(
88e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// OutputPrice overridden
require
.
InDelta
(
t
,
88e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
88e-6
,
pricing
.
OutputPricePerTokenPriority
,
1e-12
)
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
require
.
InDelta
(
t
,
3e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_OverrideAllFields
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
10e-6
),
OutputPrice
:
testPtrFloat64
(
20e-6
),
CacheWritePrice
:
testPtrFloat64
(
5e-6
),
CacheReadPrice
:
testPtrFloat64
(
1e-6
),
ImageOutputPrice
:
testPtrFloat64
(
50e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
10e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
10e-6
,
pricing
.
InputPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
20e-6
,
pricing
.
OutputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
20e-6
,
pricing
.
OutputPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreation5mPrice
,
1e-12
)
require
.
InDelta
(
t
,
5e-6
,
pricing
.
CacheCreation1hPrice
,
1e-12
)
require
.
InDelta
(
t
,
1e-6
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
1e-6
,
pricing
.
CacheReadPricePerTokenPriority
,
1e-12
)
require
.
InDelta
(
t
,
50e-6
,
pricing
.
ImageOutputPricePerToken
,
1e-12
)
}
func
TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
CacheWritePrice
:
testPtrFloat64
(
7e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreationPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreation5mPrice
,
1e-12
)
require
.
InDelta
(
t
,
7e-6
,
pricing
.
CacheCreation1hPrice
,
1e-12
)
}
func
TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
CacheReadPrice
:
testPtrFloat64
(
2e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"claude-sonnet-4"
,
chPricing
)
require
.
NoError
(
t
,
err
)
// CacheReadPrice should set both normal and priority
require
.
InDelta
(
t
,
2e-6
,
pricing
.
CacheReadPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
2e-6
,
pricing
.
CacheReadPricePerTokenPriority
,
1e-12
)
}
func
TestGetModelPricingWithChannel_UnknownModelReturnsError
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
chPricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
1e-6
),
}
pricing
,
err
:=
svc
.
GetModelPricingWithChannel
(
"totally-unknown-model"
,
chPricing
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
pricing
)
require
.
Contains
(
t
,
err
.
Error
(),
"pricing not found"
)
}
backend/internal/service/billing_service_unified_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// CalculateCostUnified
// ---------------------------------------------------------------------------
func
TestCalculateCostUnified_NilResolver_FallsBackToOldPath
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
input
:=
CostInput
{
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
nil
,
// no resolver
}
cost
,
err
:=
svc
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
// Should match the old-path result exactly
expected
,
err
:=
svc
.
calculateCostInternal
(
"claude-sonnet-4"
,
tokens
,
1.0
,
""
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
expected
.
TotalCost
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
require
.
Empty
(
t
,
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_TokenMode
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.5
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
expectedTotal
:=
1000
*
3e-6
+
500
*
15e-6
require
.
InDelta
(
t
,
expectedTotal
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedTotal
*
1.5
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModeToken
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_PerRequestMode
(
t
*
testing
.
T
)
{
// Set up a ChannelService with a per-request pricing channel
cs
:=
newTestChannelServiceWithCache
(
t
,
&
channelCache
{
pricingByGroupModel
:
map
[
channelModelKey
]
*
ChannelModelPricing
{
{
groupID
:
1
,
model
:
"claude-sonnet-4"
}
:
{
BillingMode
:
BillingModePerRequest
,
PerRequestPrice
:
testPtrFloat64
(
0.05
),
},
},
channelByGroupID
:
map
[
int64
]
*
Channel
{
1
:
{
ID
:
1
,
Status
:
StatusActive
},
},
groupPlatform
:
map
[
int64
]
string
{
1
:
""
},
wildcardByGroupPlatform
:
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
{},
mappingByGroupModel
:
map
[
channelModelKey
]
string
{},
wildcardMappingByGP
:
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
{},
byID
:
map
[
int64
]
*
Channel
{},
})
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
cs
,
bs
)
groupID
:=
int64
(
1
)
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
GroupID
:
&
groupID
,
Tokens
:
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
},
RequestCount
:
3
,
RateMultiplier
:
2.0
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 3 requests * $0.05 = $0.15
require
.
InDelta
(
t
,
0.15
,
cost
.
TotalCost
,
1e-10
)
// ActualCost = 0.15 * 2.0 = 0.30
require
.
InDelta
(
t
,
0.30
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModePerRequest
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_ImageMode
(
t
*
testing
.
T
)
{
cs
:=
newTestChannelServiceWithCache
(
t
,
&
channelCache
{
pricingByGroupModel
:
map
[
channelModelKey
]
*
ChannelModelPricing
{
{
groupID
:
2
,
model
:
"gemini-image"
}
:
{
BillingMode
:
BillingModeImage
,
PerRequestPrice
:
testPtrFloat64
(
0.10
),
},
},
channelByGroupID
:
map
[
int64
]
*
Channel
{
2
:
{
ID
:
2
,
Status
:
StatusActive
},
},
groupPlatform
:
map
[
int64
]
string
{
2
:
""
},
wildcardByGroupPlatform
:
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
{},
mappingByGroupModel
:
map
[
channelModelKey
]
string
{},
wildcardMappingByGP
:
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
{},
byID
:
map
[
int64
]
*
Channel
{},
})
bs
:=
&
BillingService
{
cfg
:
&
config
.
Config
{},
fallbackPrices
:
map
[
string
]
*
ModelPricing
{},
}
resolver
:=
NewModelPricingResolver
(
cs
,
bs
)
groupID
:=
int64
(
2
)
input
:=
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"gemini-image"
,
GroupID
:
&
groupID
,
Tokens
:
UsageTokens
{},
RequestCount
:
2
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
input
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 2 * $0.10 = $0.20
require
.
InDelta
(
t
,
0.20
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
0.20
,
cost
.
ActualCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModeImage
),
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
costZero
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
0
,
// should default to 1.0
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costZero
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
}
costNeg
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
-
5.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
tokens
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costNeg
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostUnified_BillingModeFieldFilled
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
cost
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
UsageTokens
{
InputTokens
:
100
},
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"token"
,
cost
.
BillingMode
)
}
func
TestCalculateCostUnified_UsesPreResolvedPricing
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingService
()
resolver
:=
NewModelPricingResolver
(
nil
,
bs
)
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
preResolved
:=
&
ResolvedPricing
{
Mode
:
BillingModePerRequest
,
DefaultPerRequestPrice
:
0.07
,
}
cost
,
err
:=
bs
.
CalculateCostUnified
(
CostInput
{
Ctx
:
context
.
Background
(),
Model
:
"claude-sonnet-4"
,
Tokens
:
UsageTokens
{
InputTokens
:
100
},
RequestCount
:
2
,
RateMultiplier
:
1.0
,
Resolver
:
resolver
,
Resolved
:
preResolved
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
cost
)
// 2 * $0.07 = $0.14
require
.
InDelta
(
t
,
0.14
,
cost
.
TotalCost
,
1e-10
)
require
.
Equal
(
t
,
string
(
BillingModePerRequest
),
cost
.
BillingMode
)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
// cache snapshot, bypassing the repository layer entirely.
func
newTestChannelServiceWithCache
(
t
*
testing
.
T
,
cache
*
channelCache
)
*
ChannelService
{
t
.
Helper
()
cs
:=
&
ChannelService
{}
cache
.
loadedAt
=
time
.
Now
()
cs
.
cache
.
Store
(
cache
)
return
cs
}
backend/internal/service/channel.go
View file @
0b746501
...
@@ -37,8 +37,10 @@ type Channel struct {
...
@@ -37,8 +37,10 @@ type Channel struct {
Name
string
Name
string
Description
string
Description
string
Status
string
Status
string
BillingModelSource
string
// "requested", "upstream", or "channel_mapped"
BillingModelSource
string
// "requested", "upstream", or "channel_mapped"
RestrictModels
bool
// 是否限制模型(仅允许定价列表中的模型)
RestrictModels
bool
// 是否限制模型(仅允许定价列表中的模型)
Features
string
// 渠道特性描述(JSON 数组),用于支付页面展示
FeaturesConfig
map
[
string
]
any
// 渠道功能配置(如 web search emulation)
CreatedAt
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
...
@@ -48,6 +50,25 @@ type Channel struct {
...
@@ -48,6 +50,25 @@ type Channel struct {
ModelPricing
[]
ChannelModelPricing
ModelPricing
[]
ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping
map
[
string
]
map
[
string
]
string
ModelMapping
map
[
string
]
map
[
string
]
string
// 账号统计定价
ApplyPricingToAccountStats
bool
// 是否应用渠道模型定价到账号统计
AccountStatsPricingRules
[]
AccountStatsPricingRule
// 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
// AccountStatsPricingRule 账号统计定价规则
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
// 多条规则按 SortOrder 排序,先命中为准。
type
AccountStatsPricingRule
struct
{
ID
int64
ChannelID
int64
Name
string
GroupIDs
[]
int64
AccountIDs
[]
int64
SortOrder
int
Pricing
[]
ChannelModelPricing
// 规则内的模型定价(复用现有定价结构)
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
}
}
// ChannelModelPricing 渠道模型定价条目
// ChannelModelPricing 渠道模型定价条目
...
@@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
...
@@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
cp
.
ModelMapping
[
platform
]
=
inner
cp
.
ModelMapping
[
platform
]
=
inner
}
}
}
}
if
c
.
FeaturesConfig
!=
nil
{
cp
.
FeaturesConfig
=
deepCopyFeaturesConfig
(
c
.
FeaturesConfig
)
}
if
c
.
AccountStatsPricingRules
!=
nil
{
cp
.
AccountStatsPricingRules
=
make
([]
AccountStatsPricingRule
,
len
(
c
.
AccountStatsPricingRules
))
for
i
,
rule
:=
range
c
.
AccountStatsPricingRules
{
cp
.
AccountStatsPricingRules
[
i
]
=
rule
if
rule
.
GroupIDs
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
GroupIDs
=
make
([]
int64
,
len
(
rule
.
GroupIDs
))
copy
(
cp
.
AccountStatsPricingRules
[
i
]
.
GroupIDs
,
rule
.
GroupIDs
)
}
if
rule
.
AccountIDs
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
AccountIDs
=
make
([]
int64
,
len
(
rule
.
AccountIDs
))
copy
(
cp
.
AccountStatsPricingRules
[
i
]
.
AccountIDs
,
rule
.
AccountIDs
)
}
if
rule
.
Pricing
!=
nil
{
cp
.
AccountStatsPricingRules
[
i
]
.
Pricing
=
make
([]
ChannelModelPricing
,
len
(
rule
.
Pricing
))
for
j
:=
range
rule
.
Pricing
{
cp
.
AccountStatsPricingRules
[
i
]
.
Pricing
[
j
]
=
rule
.
Pricing
[
j
]
.
Clone
()
}
}
}
}
return
&
cp
return
&
cp
}
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func
(
c
*
Channel
)
IsWebSearchEmulationEnabled
(
platform
string
)
bool
{
if
c
==
nil
||
c
.
FeaturesConfig
==
nil
{
return
false
}
wse
,
ok
:=
c
.
FeaturesConfig
[
featureKeyWebSearchEmulation
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
false
}
enabled
,
ok
:=
wse
[
platform
]
.
(
bool
)
return
ok
&&
enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func
deepCopyFeaturesConfig
(
src
map
[
string
]
any
)
map
[
string
]
any
{
dst
:=
make
(
map
[
string
]
any
,
len
(
src
))
for
k
,
v
:=
range
src
{
if
inner
,
ok
:=
v
.
(
map
[
string
]
any
);
ok
{
dst
[
k
]
=
deepCopyFeaturesConfig
(
inner
)
}
else
{
dst
[
k
]
=
v
}
}
return
dst
}
// ValidateIntervals 校验区间列表的合法性。
// ValidateIntervals 校验区间列表的合法性。
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
...
...
backend/internal/service/channel_service.go
View file @
0b746501
...
@@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
...
@@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
type
channelCache
struct
{
type
channelCache
struct
{
// 热路径查找
// 热路径查找
pricingByGroupModel
map
[
channelModelKey
]
*
ChannelModelPricing
// (groupID, platform, model) → 定价
pricingByGroupModel
map
[
channelModelKey
]
*
ChannelModelPricing
// (groupID, platform, model) → 定价
wildcardByGroupPlatform
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
// (groupID, platform) → 通配符定价(
前缀长度降序
)
wildcardByGroupPlatform
map
[
channelGroupPlatformKey
][]
*
wildcardPricingEntry
// (groupID, platform) → 通配符定价(
按配置顺序,先匹配先使用
)
mappingByGroupModel
map
[
channelModelKey
]
string
// (groupID, platform, model) → 映射目标
mappingByGroupModel
map
[
channelModelKey
]
string
// (groupID, platform, model) → 映射目标
wildcardMappingByGP
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
// (groupID, platform) → 通配符映射(
前缀长度降序
)
wildcardMappingByGP
map
[
channelGroupPlatformKey
][]
*
wildcardMappingEntry
// (groupID, platform) → 通配符映射(
按配置顺序,先匹配先使用
)
channelByGroupID
map
[
int64
]
*
Channel
// groupID → 渠道
channelByGroupID
map
[
int64
]
*
Channel
// groupID → 渠道
groupPlatform
map
[
int64
]
string
// groupID → platform
groupPlatform
map
[
int64
]
string
// groupID → platform
...
@@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
...
@@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache
(
cache
,
ch
,
gid
,
platform
)
expandMappingToCache
(
cache
,
ch
,
gid
,
platform
)
}
}
}
}
return
cache
return
cache
}
}
...
@@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
...
@@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return
ch
.
Clone
(),
nil
return
ch
.
Clone
(),
nil
}
}
// GetGroupPlatform 获取分组的平台标识(从缓存)
func
(
s
*
ChannelService
)
GetGroupPlatform
(
ctx
context
.
Context
,
groupID
int64
)
string
{
cache
,
err
:=
s
.
loadCache
(
ctx
)
if
err
!=
nil
{
return
""
}
return
cache
.
groupPlatform
[
groupID
]
}
// channelLookup 热路径公共查找结果
// channelLookup 热路径公共查找结果
type
channelLookup
struct
{
type
channelLookup
struct
{
cache
*
channelCache
cache
*
channelCache
...
@@ -556,13 +566,19 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
...
@@ -556,13 +566,19 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
// Create 和 Update 共用此函数,避免重复。
func
validateChannelConfig
(
pricing
[]
ChannelModelPricing
,
mapping
map
[
string
]
map
[
string
]
string
)
error
{
func
validateChannelConfig
(
pricing
[]
ChannelModelPricing
,
mapping
map
[
string
]
map
[
string
]
string
)
error
{
if
err
:=
validate
NoConfl
ic
t
ing
Model
s
(
pricing
);
err
!=
nil
{
if
err
:=
validate
Pr
icing
Entrie
s
(
pricing
);
err
!=
nil
{
return
err
return
err
}
}
if
err
:=
validatePricingIntervals
(
pricing
);
err
!=
nil
{
return
validateNoConflictingMappings
(
mapping
)
}
// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验),
// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。
func
validatePricingEntries
(
pricing
[]
ChannelModelPricing
)
error
{
if
err
:=
validateNoConflictingModels
(
pricing
);
err
!=
nil
{
return
err
return
err
}
}
if
err
:=
validate
NoConfl
ic
t
ing
Mappings
(
mapp
ing
);
err
!=
nil
{
if
err
:=
validate
Pr
icing
Intervals
(
pric
ing
);
err
!=
nil
{
return
err
return
err
}
}
return
validatePricingBillingMode
(
pricing
)
return
validatePricingBillingMode
(
pricing
)
...
@@ -655,14 +671,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
...
@@ -655,14 +671,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
}
channel
:=
&
Channel
{
channel
:=
&
Channel
{
Name
:
input
.
Name
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Description
:
input
.
Description
,
Status
:
StatusActive
,
Status
:
StatusActive
,
BillingModelSource
:
input
.
BillingModelSource
,
BillingModelSource
:
input
.
BillingModelSource
,
RestrictModels
:
input
.
RestrictModels
,
RestrictModels
:
input
.
RestrictModels
,
GroupIDs
:
input
.
GroupIDs
,
GroupIDs
:
input
.
GroupIDs
,
ModelPricing
:
input
.
ModelPricing
,
ModelPricing
:
input
.
ModelPricing
,
ModelMapping
:
input
.
ModelMapping
,
ModelMapping
:
input
.
ModelMapping
,
Features
:
input
.
Features
,
FeaturesConfig
:
input
.
FeaturesConfig
,
ApplyPricingToAccountStats
:
input
.
ApplyPricingToAccountStats
,
AccountStatsPricingRules
:
input
.
AccountStatsPricingRules
,
}
}
if
channel
.
BillingModelSource
==
""
{
if
channel
.
BillingModelSource
==
""
{
channel
.
BillingModelSource
=
BillingModelSourceChannelMapped
channel
.
BillingModelSource
=
BillingModelSourceChannelMapped
...
@@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
...
@@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
for
i
,
rule
:=
range
channel
.
AccountStatsPricingRules
{
if
err
:=
validatePricingEntries
(
rule
.
Pricing
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"account stats pricing rule #%d: %w"
,
i
+
1
,
err
)
}
}
if
err
:=
s
.
repo
.
Create
(
ctx
,
channel
);
err
!=
nil
{
if
err
:=
s
.
repo
.
Create
(
ctx
,
channel
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create channel: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"create channel: %w"
,
err
)
...
@@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
...
@@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
if
err
:=
validateChannelConfig
(
channel
.
ModelPricing
,
channel
.
ModelMapping
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
for
i
,
rule
:=
range
channel
.
AccountStatsPricingRules
{
if
err
:=
validatePricingEntries
(
rule
.
Pricing
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"account stats pricing rule #%d: %w"
,
i
+
1
,
err
)
}
}
oldGroupIDs
:=
s
.
getOldGroupIDs
(
ctx
,
id
)
oldGroupIDs
:=
s
.
getOldGroupIDs
(
ctx
,
id
)
...
@@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
...
@@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if
input
.
RestrictModels
!=
nil
{
if
input
.
RestrictModels
!=
nil
{
channel
.
RestrictModels
=
*
input
.
RestrictModels
channel
.
RestrictModels
=
*
input
.
RestrictModels
}
}
if
input
.
Features
!=
nil
{
channel
.
Features
=
*
input
.
Features
}
if
input
.
GroupIDs
!=
nil
{
if
input
.
GroupIDs
!=
nil
{
if
err
:=
s
.
checkGroupConflicts
(
ctx
,
channel
.
ID
,
*
input
.
GroupIDs
);
err
!=
nil
{
if
err
:=
s
.
checkGroupConflicts
(
ctx
,
channel
.
ID
,
*
input
.
GroupIDs
);
err
!=
nil
{
return
err
return
err
...
@@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
...
@@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if
input
.
BillingModelSource
!=
""
{
if
input
.
BillingModelSource
!=
""
{
channel
.
BillingModelSource
=
input
.
BillingModelSource
channel
.
BillingModelSource
=
input
.
BillingModelSource
}
}
if
input
.
FeaturesConfig
!=
nil
{
channel
.
FeaturesConfig
=
input
.
FeaturesConfig
}
if
input
.
ApplyPricingToAccountStats
!=
nil
{
channel
.
ApplyPricingToAccountStats
=
*
input
.
ApplyPricingToAccountStats
}
if
input
.
AccountStatsPricingRules
!=
nil
{
channel
.
AccountStatsPricingRules
=
*
input
.
AccountStatsPricingRules
}
return
nil
return
nil
}
}
...
@@ -913,23 +955,31 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
...
@@ -913,23 +955,31 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
// CreateChannelInput 创建渠道输入
// CreateChannelInput 创建渠道输入
type
CreateChannelInput
struct
{
type
CreateChannelInput
struct
{
Name
string
Name
string
Description
string
Description
string
GroupIDs
[]
int64
GroupIDs
[]
int64
ModelPricing
[]
ChannelModelPricing
ModelPricing
[]
ChannelModelPricing
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
BillingModelSource
string
BillingModelSource
string
RestrictModels
bool
RestrictModels
bool
Features
string
FeaturesConfig
map
[
string
]
any
ApplyPricingToAccountStats
bool
AccountStatsPricingRules
[]
AccountStatsPricingRule
}
}
// UpdateChannelInput 更新渠道输入
// UpdateChannelInput 更新渠道输入
type
UpdateChannelInput
struct
{
type
UpdateChannelInput
struct
{
Name
string
Name
string
Description
*
string
Description
*
string
Status
string
Status
string
GroupIDs
*
[]
int64
GroupIDs
*
[]
int64
ModelPricing
*
[]
ChannelModelPricing
ModelPricing
*
[]
ChannelModelPricing
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
ModelMapping
map
[
string
]
map
[
string
]
string
// platform → {src→dst}
BillingModelSource
string
BillingModelSource
string
RestrictModels
*
bool
RestrictModels
*
bool
Features
*
string
FeaturesConfig
map
[
string
]
any
ApplyPricingToAccountStats
*
bool
AccountStatsPricingRules
*
[]
AccountStatsPricingRule
}
}
backend/internal/service/channel_websearch_test.go
0 → 100644
View file @
0b746501
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestChannel_IsWebSearchEmulationEnabled_Enabled
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
true
},
},
}
require
.
True
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
true
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"openai"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_Disabled
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
false
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
nil
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_NilChannel
(
t
*
testing
.
T
)
{
var
c
*
Channel
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_WrongStructure
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
true
,
// not a map
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
func
TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool
(
t
*
testing
.
T
)
{
c
:=
&
Channel
{
FeaturesConfig
:
map
[
string
]
any
{
featureKeyWebSearchEmulation
:
map
[
string
]
any
{
"anthropic"
:
"yes"
},
},
}
require
.
False
(
t
,
c
.
IsWebSearchEmulationEnabled
(
"anthropic"
))
}
backend/internal/service/concurrency_service.go
View file @
0b746501
...
@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
...
@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}()
}
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Returns a map of accountID -> current concurrency count
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
if
len
(
accountIDs
)
==
0
{
return
map
[
int64
]
int
{},
nil
return
map
[
int64
]
int
{},
nil
...
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
...
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
}
return
result
,
nil
return
result
,
nil
}
}
return
s
.
cache
.
GetAccountConcurrencyBatch
(
ctx
,
accountIDs
)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
defer
cancel
()
return
s
.
cache
.
GetAccountConcurrencyBatch
(
redisCtx
,
accountIDs
)
}
}
backend/internal/service/domain_constants.go
View file @
0b746501
...
@@ -249,6 +249,18 @@ const (
...
@@ -249,6 +249,18 @@ const (
SettingKeyEnableMetadataPassthrough
=
"enable_metadata_passthrough"
SettingKeyEnableMetadataPassthrough
=
"enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning
=
"enable_cch_signing"
SettingKeyEnableCCHSigning
=
"enable_cch_signing"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled
=
"balance_low_notify_enabled"
// 全局开关
SettingKeyBalanceLowNotifyThreshold
=
"balance_low_notify_threshold"
// 默认阈值(USD)
SettingKeyBalanceLowNotifyRechargeURL
=
"balance_low_notify_recharge_url"
// 充值页面 URL
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled
=
"account_quota_notify_enabled"
// 全局开关
SettingKeyAccountQuotaNotifyEmails
=
"account_quota_notify_emails"
// 管理员通知邮箱列表(JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig
=
"web_search_emulation_config"
// JSON 配置
)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/email_service.go
View file @
0b746501
...
@@ -7,8 +7,9 @@ import (
...
@@ -7,8 +7,9 @@ import (
"crypto/tls"
"crypto/tls"
"encoding/hex"
"encoding/hex"
"fmt"
"fmt"
"log"
"log
/slog
"
"math/big"
"math/big"
"net"
"net/smtp"
"net/smtp"
"net/url"
"net/url"
"strconv"
"strconv"
...
@@ -34,6 +35,11 @@ type EmailCache interface {
...
@@ -34,6 +35,11 @@ type EmailCache interface {
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Notify email verification code methods
GetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteNotifyVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
...
@@ -43,6 +49,10 @@ type EmailCache interface {
...
@@ -43,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently)
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
// Notify code rate limiting per user
IncrNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
,
window
time
.
Duration
)
(
int64
,
error
)
GetNotifyCodeUserRate
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
}
}
// VerificationCodeData represents verification code data
// VerificationCodeData represents verification code data
...
@@ -50,6 +60,7 @@ type VerificationCodeData struct {
...
@@ -50,6 +60,7 @@ type VerificationCodeData struct {
Code
string
Code
string
Attempts
int
Attempts
int
CreatedAt
time
.
Time
CreatedAt
time
.
Time
ExpiresAt
time
.
Time
// absolute expiry; used to preserve remaining TTL when updating attempts
}
}
// PasswordResetTokenData represents password reset token data
// PasswordResetTokenData represents password reset token data
...
@@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
...
@@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return
s
.
SendEmailWithConfig
(
config
,
to
,
subject
,
body
)
return
s
.
SendEmailWithConfig
(
config
,
to
,
subject
,
body
)
}
}
const
smtpDialTimeout
=
10
*
time
.
Second
const
smtpIOTimeout
=
20
*
time
.
Second
// SendEmailWithConfig 使用指定配置发送邮件
// SendEmailWithConfig 使用指定配置发送邮件
func
(
s
*
EmailService
)
SendEmailWithConfig
(
config
*
SMTPConfig
,
to
,
subject
,
body
string
)
error
{
func
(
s
*
EmailService
)
SendEmailWithConfig
(
config
*
SMTPConfig
,
to
,
subject
,
body
string
)
error
{
from
:=
config
.
From
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
to
=
sanitizeEmailHeader
(
to
)
subject
=
sanitizeEmailHeader
(
subject
)
from
:=
sanitizeEmailHeader
(
config
.
From
)
if
config
.
FromName
!=
""
{
if
config
.
FromName
!=
""
{
from
=
fmt
.
Sprintf
(
"%s <%s>"
,
config
.
FromName
,
config
.
From
)
from
=
fmt
.
Sprintf
(
"%s <%s>"
,
sanitizeEmailHeader
(
config
.
FromName
)
,
sanitizeEmailHeader
(
config
.
From
)
)
}
}
msg
:=
fmt
.
Sprintf
(
"From: %s
\r\n
To: %s
\r\n
Subject: %s
\r\n
MIME-Version: 1.0
\r\n
Content-Type: text/html; charset=UTF-8
\r\n\r\n
%s"
,
msg
:=
fmt
.
Sprintf
(
"From: %s
\r\n
To: %s
\r\n
Subject: %s
\r\n
MIME-Version: 1.0
\r\n
Content-Type: text/html; charset=UTF-8
\r\n\r\n
%s"
,
...
@@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
...
@@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return
s
.
sendMailTLS
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
return
s
.
sendMailTLS
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
}
}
return
smtp
.
SendMail
(
addr
,
auth
,
config
.
From
,
[]
string
{
to
},
[]
byte
(
msg
))
return
s
.
sendMailPlain
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func
(
s
*
EmailService
)
sendMailPlain
(
addr
string
,
auth
smtp
.
Auth
,
from
,
to
string
,
msg
[]
byte
,
host
string
)
error
{
dialer
:=
&
net
.
Dialer
{
Timeout
:
smtpDialTimeout
}
conn
,
err
:=
dialer
.
Dial
(
"tcp"
,
addr
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp dial: %w"
,
err
)
}
_
=
conn
.
SetDeadline
(
time
.
Now
()
.
Add
(
smtpIOTimeout
))
defer
func
()
{
_
=
conn
.
Close
()
}()
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"new smtp client: %w"
,
err
)
}
defer
func
()
{
_
=
client
.
Close
()
}()
// Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it.
// This mirrors the behavior of smtp.SendMail which we replaced for timeout support.
if
ok
,
_
:=
client
.
Extension
(
"STARTTLS"
);
ok
{
if
err
=
client
.
StartTLS
(
&
tls
.
Config
{
ServerName
:
host
,
MinVersion
:
tls
.
VersionTLS12
});
err
!=
nil
{
return
fmt
.
Errorf
(
"starttls: %w"
,
err
)
}
}
if
err
=
client
.
Auth
(
auth
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp auth: %w"
,
err
)
}
if
err
=
client
.
Mail
(
from
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp mail: %w"
,
err
)
}
if
err
=
client
.
Rcpt
(
to
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp rcpt: %w"
,
err
)
}
w
,
err
:=
client
.
Data
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp data: %w"
,
err
)
}
if
_
,
err
=
w
.
Write
(
msg
);
err
!=
nil
{
return
fmt
.
Errorf
(
"write msg: %w"
,
err
)
}
if
err
=
w
.
Close
();
err
!=
nil
{
return
fmt
.
Errorf
(
"close writer: %w"
,
err
)
}
_
=
client
.
Quit
()
return
nil
}
}
// sendMailTLS 使用TLS发送邮件
// sendMailTLS 使用TLS发送邮件
...
@@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
...
@@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion
:
tls
.
VersionTLS12
,
MinVersion
:
tls
.
VersionTLS12
,
}
}
conn
,
err
:=
tls
.
Dial
(
"tcp"
,
addr
,
tlsConfig
)
dialer
:=
&
net
.
Dialer
{
Timeout
:
smtpDialTimeout
}
conn
,
err
:=
tls
.
DialWithDialer
(
dialer
,
"tcp"
,
addr
,
tlsConfig
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"tls dial: %w"
,
err
)
return
fmt
.
Errorf
(
"tls dial: %w"
,
err
)
}
}
_
=
conn
.
SetDeadline
(
time
.
Now
()
.
Add
(
smtpIOTimeout
))
defer
func
()
{
_
=
conn
.
Close
()
}()
defer
func
()
{
_
=
conn
.
Close
()
}()
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
...
@@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
...
@@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code
:
code
,
Code
:
code
,
Attempts
:
0
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
(),
CreatedAt
:
time
.
Now
(),
ExpiresAt
:
time
.
Now
()
.
Add
(
verifyCodeTTL
),
}
}
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save verify code: %w"
,
err
)
return
fmt
.
Errorf
(
"save verify code: %w"
,
err
)
...
@@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
...
@@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
remaining
:=
time
.
Until
(
data
.
ExpiresAt
)
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
if
remaining
<=
0
{
return
ErrInvalidVerifyCode
}
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
remaining
);
err
!=
nil
{
slog
.
Error
(
"failed to update verification attempt count"
,
"email"
,
email
,
"error"
,
err
)
}
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
return
ErrVerifyCodeMaxAttempts
...
@@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
...
@@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码
// 验证成功,删除验证码
if
err
:=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to delete verification code after success
: %v
"
,
err
)
s
log
.
Error
(
"f
ailed to delete verification code after success
"
,
"email"
,
email
,
"error
"
,
err
)
}
}
return
nil
return
nil
}
}
...
@@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
...
@@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] P
assword reset email skipped
(
cooldown
): %s
"
,
email
)
s
log
.
Info
(
"p
assword reset email skipped
due to
cooldown
"
,
"email
"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
return
nil
// Silent success to prevent revealing cooldown to attackers
}
}
...
@@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
...
@@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration)
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to set password reset cooldown
for %s: %v"
,
email
,
err
)
s
log
.
Error
(
"f
ailed to set password reset cooldown
"
,
"email"
,
email
,
"error"
,
err
)
}
}
return
nil
return
nil
...
@@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
...
@@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use)
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] F
ailed to delete password reset token after consumption
: %v
"
,
err
)
s
log
.
Error
(
"f
ailed to delete password reset token after consumption
"
,
"email"
,
email
,
"error
"
,
err
)
}
}
return
nil
return
nil
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment