Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3fcb0cc3
Commit
3fcb0cc3
authored
Feb 10, 2026
by
yangjianbo
Browse files
feat(subscription): 有界队列执行维护并改进鉴权解析
parent
2bfb1629
Changes
13
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire.go
View file @
3fcb0cc3
...
@@ -76,6 +76,7 @@ func provideCleanup(
...
@@ -76,6 +76,7 @@ func provideCleanup(
pricing
*
service
.
PricingService
,
pricing
*
service
.
PricingService
,
emailQueue
*
service
.
EmailQueueService
,
emailQueue
*
service
.
EmailQueueService
,
billingCache
*
service
.
BillingCacheService
,
billingCache
*
service
.
BillingCacheService
,
subscriptionService
*
service
.
SubscriptionService
,
oauth
*
service
.
OAuthService
,
oauth
*
service
.
OAuthService
,
openaiOAuth
*
service
.
OpenAIOAuthService
,
openaiOAuth
*
service
.
OpenAIOAuthService
,
geminiOAuth
*
service
.
GeminiOAuthService
,
geminiOAuth
*
service
.
GeminiOAuthService
,
...
@@ -150,6 +151,12 @@ func provideCleanup(
...
@@ -150,6 +151,12 @@ func provideCleanup(
subscriptionExpiry
.
Stop
()
subscriptionExpiry
.
Stop
()
return
nil
return
nil
}},
}},
{
"SubscriptionService"
,
func
()
error
{
if
subscriptionService
!=
nil
{
subscriptionService
.
Stop
()
}
return
nil
}},
{
"PricingService"
,
func
()
error
{
{
"PricingService"
,
func
()
error
{
pricing
.
Stop
()
pricing
.
Stop
()
return
nil
return
nil
...
...
backend/cmd/server/wire_gen.go
View file @
3fcb0cc3
...
@@ -204,7 +204,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -204,7 +204,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
soraAccountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeTokenCacheInvalidator
,
schedulerCache
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
soraAccountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeTokenCacheInvalidator
,
schedulerCache
,
configConfig
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
subscriptionExpiryService
:=
service
.
ProvideSubscriptionExpiryService
(
userSubscriptionRepository
)
subscriptionExpiryService
:=
service
.
ProvideSubscriptionExpiryService
(
userSubscriptionRepository
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
soraMediaCleanupService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
subscriptionExpiryService
,
usageCleanupService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
soraMediaCleanupService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
subscriptionExpiryService
,
usageCleanupService
,
pricingService
,
emailQueueService
,
billingCacheService
,
subscriptionService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
application
:=
&
Application
{
application
:=
&
Application
{
Server
:
httpServer
,
Server
:
httpServer
,
Cleanup
:
v
,
Cleanup
:
v
,
...
@@ -243,6 +243,7 @@ func provideCleanup(
...
@@ -243,6 +243,7 @@ func provideCleanup(
pricing
*
service
.
PricingService
,
pricing
*
service
.
PricingService
,
emailQueue
*
service
.
EmailQueueService
,
emailQueue
*
service
.
EmailQueueService
,
billingCache
*
service
.
BillingCacheService
,
billingCache
*
service
.
BillingCacheService
,
subscriptionService
*
service
.
SubscriptionService
,
oauth
*
service
.
OAuthService
,
oauth
*
service
.
OAuthService
,
openaiOAuth
*
service
.
OpenAIOAuthService
,
openaiOAuth
*
service
.
OpenAIOAuthService
,
geminiOAuth
*
service
.
GeminiOAuthService
,
geminiOAuth
*
service
.
GeminiOAuthService
,
...
@@ -316,6 +317,12 @@ func provideCleanup(
...
@@ -316,6 +317,12 @@ func provideCleanup(
subscriptionExpiry
.
Stop
()
subscriptionExpiry
.
Stop
()
return
nil
return
nil
}},
}},
{
"SubscriptionService"
,
func
()
error
{
if
subscriptionService
!=
nil
{
subscriptionService
.
Stop
()
}
return
nil
}},
{
"PricingService"
,
func
()
error
{
{
"PricingService"
,
func
()
error
{
pricing
.
Stop
()
pricing
.
Stop
()
return
nil
return
nil
...
...
backend/internal/config/config.go
View file @
3fcb0cc3
...
@@ -38,33 +38,34 @@ const (
...
@@ -38,33 +38,34 @@ const (
)
)
type
Config
struct
{
type
Config
struct
{
Server
ServerConfig
`mapstructure:"server"`
Server
ServerConfig
`mapstructure:"server"`
CORS
CORSConfig
`mapstructure:"cors"`
CORS
CORSConfig
`mapstructure:"cors"`
Security
SecurityConfig
`mapstructure:"security"`
Security
SecurityConfig
`mapstructure:"security"`
Billing
BillingConfig
`mapstructure:"billing"`
Billing
BillingConfig
`mapstructure:"billing"`
Turnstile
TurnstileConfig
`mapstructure:"turnstile"`
Turnstile
TurnstileConfig
`mapstructure:"turnstile"`
Database
DatabaseConfig
`mapstructure:"database"`
Database
DatabaseConfig
`mapstructure:"database"`
Redis
RedisConfig
`mapstructure:"redis"`
Redis
RedisConfig
`mapstructure:"redis"`
Ops
OpsConfig
`mapstructure:"ops"`
Ops
OpsConfig
`mapstructure:"ops"`
JWT
JWTConfig
`mapstructure:"jwt"`
JWT
JWTConfig
`mapstructure:"jwt"`
Totp
TotpConfig
`mapstructure:"totp"`
Totp
TotpConfig
`mapstructure:"totp"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
Default
DefaultConfig
`mapstructure:"default"`
Default
DefaultConfig
`mapstructure:"default"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
APIKeyAuth
APIKeyAuthCacheConfig
`mapstructure:"api_key_auth_cache"`
APIKeyAuth
APIKeyAuthCacheConfig
`mapstructure:"api_key_auth_cache"`
SubscriptionCache
SubscriptionCacheConfig
`mapstructure:"subscription_cache"`
SubscriptionCache
SubscriptionCacheConfig
`mapstructure:"subscription_cache"`
Dashboard
DashboardCacheConfig
`mapstructure:"dashboard_cache"`
SubscriptionMaintenance
SubscriptionMaintenanceConfig
`mapstructure:"subscription_maintenance"`
DashboardAgg
DashboardAggregationConfig
`mapstructure:"dashboard_aggregation"`
Dashboard
DashboardCacheConfig
`mapstructure:"dashboard_cache"`
UsageCleanup
UsageCleanupConfig
`mapstructure:"usage_cleanup"`
DashboardAgg
DashboardAggregationConfig
`mapstructure:"dashboard_aggregation"`
Concurrency
ConcurrencyConfig
`mapstructure:"concurrency"`
UsageCleanup
UsageCleanupConfig
`mapstructure:"usage_cleanup"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
Concurrency
ConcurrencyConfig
`mapstructure:"concurrency"`
Sora
SoraConfig
`mapstructure:"sora"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Sora
SoraConfig
`mapstructure:"sora"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Gemini
GeminiConfig
`mapstructure:"gemini"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
Update
UpdateConfig
`mapstructure:"update"`
Gemini
GeminiConfig
`mapstructure:"gemini"`
Update
UpdateConfig
`mapstructure:"update"`
}
}
type
GeminiConfig
struct
{
type
GeminiConfig
struct
{
...
@@ -609,6 +610,13 @@ type SubscriptionCacheConfig struct {
...
@@ -609,6 +610,13 @@ type SubscriptionCacheConfig struct {
JitterPercent
int
`mapstructure:"jitter_percent"`
JitterPercent
int
`mapstructure:"jitter_percent"`
}
}
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
type
SubscriptionMaintenanceConfig
struct
{
WorkerCount
int
`mapstructure:"worker_count"`
QueueSize
int
`mapstructure:"queue_size"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
// DashboardCacheConfig 仪表盘统计缓存配置
type
DashboardCacheConfig
struct
{
type
DashboardCacheConfig
struct
{
// Enabled: 是否启用仪表盘缓存
// Enabled: 是否启用仪表盘缓存
...
@@ -734,15 +742,6 @@ func Load() (*Config, error) {
...
@@ -734,15 +742,6 @@ func Load() (*Config, error) {
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
)
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
)
cfg
.
Security
.
CSP
.
Policy
=
strings
.
TrimSpace
(
cfg
.
Security
.
CSP
.
Policy
)
cfg
.
Security
.
CSP
.
Policy
=
strings
.
TrimSpace
(
cfg
.
Security
.
CSP
.
Policy
)
if
cfg
.
JWT
.
Secret
==
""
{
secret
,
err
:=
generateJWTSecret
(
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate jwt secret error: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"Warning: JWT secret auto-generated. Consider setting a fixed secret for production."
)
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg
.
Totp
.
EncryptionKey
=
strings
.
TrimSpace
(
cfg
.
Totp
.
EncryptionKey
)
cfg
.
Totp
.
EncryptionKey
=
strings
.
TrimSpace
(
cfg
.
Totp
.
EncryptionKey
)
if
cfg
.
Totp
.
EncryptionKey
==
""
{
if
cfg
.
Totp
.
EncryptionKey
==
""
{
...
@@ -1057,9 +1056,30 @@ func setDefaults() {
...
@@ -1057,9 +1056,30 @@ func setDefaults() {
// Security - proxy fallback
// Security - proxy fallback
viper
.
SetDefault
(
"security.proxy_fallback.allow_direct_on_error"
,
false
)
viper
.
SetDefault
(
"security.proxy_fallback.allow_direct_on_error"
,
false
)
// Subscription Maintenance (bounded queue + worker pool)
viper
.
SetDefault
(
"subscription_maintenance.worker_count"
,
2
)
viper
.
SetDefault
(
"subscription_maintenance.queue_size"
,
1024
)
}
}
func
(
c
*
Config
)
Validate
()
error
{
func
(
c
*
Config
)
Validate
()
error
{
jwtSecret
:=
strings
.
TrimSpace
(
c
.
JWT
.
Secret
)
if
jwtSecret
==
""
{
return
fmt
.
Errorf
(
"jwt.secret is required"
)
}
// NOTE: 按 UTF-8 编码后的字节长度计算。
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
if
len
([]
byte
(
jwtSecret
))
<
32
{
return
fmt
.
Errorf
(
"jwt.secret must be at least 32 bytes"
)
}
if
c
.
SubscriptionMaintenance
.
WorkerCount
<
0
{
return
fmt
.
Errorf
(
"subscription_maintenance.worker_count must be non-negative"
)
}
if
c
.
SubscriptionMaintenance
.
QueueSize
<
0
{
return
fmt
.
Errorf
(
"subscription_maintenance.queue_size must be non-negative"
)
}
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
geminiClientID
:=
strings
.
TrimSpace
(
c
.
Gemini
.
OAuth
.
ClientID
)
geminiClientID
:=
strings
.
TrimSpace
(
c
.
Gemini
.
OAuth
.
ClientID
)
...
...
backend/internal/config/config_test.go
View file @
3fcb0cc3
...
@@ -8,6 +8,12 @@ import (
...
@@ -8,6 +8,12 @@ import (
"github.com/spf13/viper"
"github.com/spf13/viper"
)
)
func
resetViperWithJWTSecret
(
t
*
testing
.
T
)
{
t
.
Helper
()
viper
.
Reset
()
t
.
Setenv
(
"JWT_SECRET"
,
strings
.
Repeat
(
"x"
,
32
))
}
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
tests
:=
[]
struct
{
input
string
input
string
...
@@ -29,7 +35,7 @@ func TestNormalizeRunMode(t *testing.T) {
...
@@ -29,7 +35,7 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
func
TestLoadDefaultSchedulingConfig
(
t
*
testing
.
T
)
{
func
TestLoadDefaultSchedulingConfig
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -57,7 +63,7 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
...
@@ -57,7 +63,7 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func
TestLoadSchedulingConfigFromEnv
(
t
*
testing
.
T
)
{
func
TestLoadSchedulingConfigFromEnv
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
t
.
Setenv
(
"GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING"
,
"5"
)
t
.
Setenv
(
"GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING"
,
"5"
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
...
@@ -71,7 +77,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
...
@@ -71,7 +77,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
func
TestLoadDefaultSecurityToggles
(
t
*
testing
.
T
)
{
func
TestLoadDefaultSecurityToggles
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -93,7 +99,7 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
...
@@ -93,7 +99,7 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
}
}
func
TestLoadDefaultServerMode
(
t
*
testing
.
T
)
{
func
TestLoadDefaultServerMode
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -106,7 +112,7 @@ func TestLoadDefaultServerMode(t *testing.T) {
...
@@ -106,7 +112,7 @@ func TestLoadDefaultServerMode(t *testing.T) {
}
}
func
TestLoadDefaultDatabaseSSLMode
(
t
*
testing
.
T
)
{
func
TestLoadDefaultDatabaseSSLMode
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -119,7 +125,7 @@ func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
...
@@ -119,7 +125,7 @@ func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
}
}
func
TestValidateLinuxDoFrontendRedirectURL
(
t
*
testing
.
T
)
{
func
TestValidateLinuxDoFrontendRedirectURL
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -144,7 +150,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
...
@@ -144,7 +150,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
func
TestValidateLinuxDoPKCERequiredForPublicClient
(
t
*
testing
.
T
)
{
func
TestValidateLinuxDoPKCERequiredForPublicClient
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -169,7 +175,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
...
@@ -169,7 +175,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
}
}
func
TestLoadDefaultDashboardCacheConfig
(
t
*
testing
.
T
)
{
func
TestLoadDefaultDashboardCacheConfig
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -194,7 +200,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
...
@@ -194,7 +200,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
}
}
func
TestValidateDashboardCacheConfigEnabled
(
t
*
testing
.
T
)
{
func
TestValidateDashboardCacheConfigEnabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -214,7 +220,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
...
@@ -214,7 +220,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
}
}
func
TestValidateDashboardCacheConfigDisabled
(
t
*
testing
.
T
)
{
func
TestValidateDashboardCacheConfigDisabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -233,7 +239,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
...
@@ -233,7 +239,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
}
}
func
TestLoadDefaultDashboardAggregationConfig
(
t
*
testing
.
T
)
{
func
TestLoadDefaultDashboardAggregationConfig
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -270,7 +276,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
...
@@ -270,7 +276,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
}
}
func
TestValidateDashboardAggregationConfigDisabled
(
t
*
testing
.
T
)
{
func
TestValidateDashboardAggregationConfigDisabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -289,7 +295,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
...
@@ -289,7 +295,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
}
}
func
TestValidateDashboardAggregationBackfillMaxDays
(
t
*
testing
.
T
)
{
func
TestValidateDashboardAggregationBackfillMaxDays
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -308,7 +314,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
...
@@ -308,7 +314,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
}
}
func
TestLoadDefaultUsageCleanupConfig
(
t
*
testing
.
T
)
{
func
TestLoadDefaultUsageCleanupConfig
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -333,7 +339,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
...
@@ -333,7 +339,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
}
}
func
TestValidateUsageCleanupConfigEnabled
(
t
*
testing
.
T
)
{
func
TestValidateUsageCleanupConfigEnabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -352,7 +358,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
...
@@ -352,7 +358,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
}
}
func
TestValidateUsageCleanupConfigDisabled
(
t
*
testing
.
T
)
{
func
TestValidateUsageCleanupConfigDisabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -451,7 +457,7 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
...
@@ -451,7 +457,7 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
}
}
func
TestValidateServerFrontendURL
(
t
*
testing
.
T
)
{
func
TestValidateServerFrontendURL
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -505,6 +511,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
...
@@ -505,6 +511,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
func
TestWarnIfInsecureURL
(
t
*
testing
.
T
)
{
func
TestWarnIfInsecureURL
(
t
*
testing
.
T
)
{
warnIfInsecureURL
(
"test"
,
"http://example.com"
)
warnIfInsecureURL
(
"test"
,
"http://example.com"
)
warnIfInsecureURL
(
"test"
,
"bad://url"
)
warnIfInsecureURL
(
"test"
,
"bad://url"
)
warnIfInsecureURL
(
"test"
,
"://invalid"
)
}
}
func
TestGenerateJWTSecretDefaultLength
(
t
*
testing
.
T
)
{
func
TestGenerateJWTSecretDefaultLength
(
t
*
testing
.
T
)
{
...
@@ -518,7 +525,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
...
@@ -518,7 +525,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
}
}
func
TestValidateOpsCleanupScheduleRequired
(
t
*
testing
.
T
)
{
func
TestValidateOpsCleanupScheduleRequired
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -536,7 +543,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
...
@@ -536,7 +543,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
}
}
func
TestValidateConcurrencyPingInterval
(
t
*
testing
.
T
)
{
func
TestValidateConcurrencyPingInterval
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -553,14 +560,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
...
@@ -553,14 +560,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
}
}
func
TestProvideConfig
(
t
*
testing
.
T
)
{
func
TestProvideConfig
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
if
_
,
err
:=
ProvideConfig
();
err
!=
nil
{
if
_
,
err
:=
ProvideConfig
();
err
!=
nil
{
t
.
Fatalf
(
"ProvideConfig() error: %v"
,
err
)
t
.
Fatalf
(
"ProvideConfig() error: %v"
,
err
)
}
}
}
}
func
TestValidateConfigWithLinuxDoEnabled
(
t
*
testing
.
T
)
{
func
TestValidateConfigWithLinuxDoEnabled
(
t
*
testing
.
T
)
{
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -604,6 +611,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
...
@@ -604,6 +611,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
}
}
}
}
func
TestDatabaseDSNWithTimezone_WithPassword
(
t
*
testing
.
T
)
{
d
:=
&
DatabaseConfig
{
Host
:
"localhost"
,
Port
:
5432
,
User
:
"u"
,
Password
:
"p"
,
DBName
:
"db"
,
SSLMode
:
"prefer"
,
}
got
:=
d
.
DSNWithTimezone
(
"UTC"
)
if
!
strings
.
Contains
(
got
,
"password=p"
)
{
t
.
Fatalf
(
"DSNWithTimezone should include password: %q"
,
got
)
}
if
!
strings
.
Contains
(
got
,
"TimeZone=UTC"
)
{
t
.
Fatalf
(
"DSNWithTimezone should include TimeZone=UTC: %q"
,
got
)
}
}
func
TestValidateAbsoluteHTTPURLMissingHost
(
t
*
testing
.
T
)
{
func
TestValidateAbsoluteHTTPURLMissingHost
(
t
*
testing
.
T
)
{
if
err
:=
ValidateAbsoluteHTTPURL
(
"https://"
);
err
==
nil
{
if
err
:=
ValidateAbsoluteHTTPURL
(
"https://"
);
err
==
nil
{
t
.
Fatalf
(
"ValidateAbsoluteHTTPURL should reject missing host"
)
t
.
Fatalf
(
"ValidateAbsoluteHTTPURL should reject missing host"
)
...
@@ -626,10 +651,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
...
@@ -626,10 +651,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL
(
"secure"
,
"https://example.com"
)
warnIfInsecureURL
(
"secure"
,
"https://example.com"
)
}
}
func
TestValidateJWTSecret_UTF8Bytes
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
// 31 bytes (< 32) even though it's 31 characters.
cfg
.
JWT
.
Secret
=
strings
.
Repeat
(
"a"
,
31
)
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() should reject 31-byte secret"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"at least 32 bytes"
)
{
t
.
Fatalf
(
"Validate() error = %v"
,
err
)
}
// 32 bytes OK.
cfg
.
JWT
.
Secret
=
strings
.
Repeat
(
"a"
,
32
)
err
=
cfg
.
Validate
()
if
err
!=
nil
{
t
.
Fatalf
(
"Validate() should accept 32-byte secret: %v"
,
err
)
}
}
func
TestValidateConfigErrors
(
t
*
testing
.
T
)
{
func
TestValidateConfigErrors
(
t
*
testing
.
T
)
{
buildValid
:=
func
(
t
*
testing
.
T
)
*
Config
{
buildValid
:=
func
(
t
*
testing
.
T
)
*
Config
{
t
.
Helper
()
t
.
Helper
()
viper
.
Res
et
()
resetViperWithJWTSecr
et
(
t
)
cfg
,
err
:=
Load
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
t
.
Fatalf
(
"Load() error: %v"
,
err
)
...
@@ -642,6 +692,26 @@ func TestValidateConfigErrors(t *testing.T) {
...
@@ -642,6 +692,26 @@ func TestValidateConfigErrors(t *testing.T) {
mutate
func
(
*
Config
)
mutate
func
(
*
Config
)
wantErr
string
wantErr
string
}{
}{
{
name
:
"jwt secret required"
,
mutate
:
func
(
c
*
Config
)
{
c
.
JWT
.
Secret
=
""
},
wantErr
:
"jwt.secret is required"
,
},
{
name
:
"jwt secret min bytes"
,
mutate
:
func
(
c
*
Config
)
{
c
.
JWT
.
Secret
=
strings
.
Repeat
(
"a"
,
31
)
},
wantErr
:
"jwt.secret must be at least 32 bytes"
,
},
{
name
:
"subscription maintenance worker_count non-negative"
,
mutate
:
func
(
c
*
Config
)
{
c
.
SubscriptionMaintenance
.
WorkerCount
=
-
1
},
wantErr
:
"subscription_maintenance.worker_count"
,
},
{
name
:
"subscription maintenance queue_size non-negative"
,
mutate
:
func
(
c
*
Config
)
{
c
.
SubscriptionMaintenance
.
QueueSize
=
-
1
},
wantErr
:
"subscription_maintenance.queue_size"
,
},
{
{
name
:
"jwt expire hour positive"
,
name
:
"jwt expire hour positive"
,
mutate
:
func
(
c
*
Config
)
{
c
.
JWT
.
ExpireHour
=
0
},
mutate
:
func
(
c
*
Config
)
{
c
.
JWT
.
ExpireHour
=
0
},
...
...
backend/internal/server/middleware/admin_auth.go
View file @
3fcb0cc3
...
@@ -58,8 +58,13 @@ func adminAuth(
...
@@ -58,8 +58,13 @@ func adminAuth(
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
if
authHeader
!=
""
{
if
authHeader
!=
""
{
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
{
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
],
"Bearer"
)
{
if
!
validateJWTForAdmin
(
c
,
parts
[
1
],
authService
,
userService
)
{
token
:=
strings
.
TrimSpace
(
parts
[
1
])
if
token
==
""
{
AbortWithError
(
c
,
401
,
"UNAUTHORIZED"
,
"Authorization required"
)
return
}
if
!
validateJWTForAdmin
(
c
,
token
,
authService
,
userService
)
{
return
return
}
}
c
.
Next
()
c
.
Next
()
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
3fcb0cc3
...
@@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
...
@@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
if
authHeader
!=
""
{
if
authHeader
!=
""
{
// 验证Bearer scheme
// 验证Bearer scheme
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
{
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
]
,
"Bearer"
)
{
apiKeyString
=
parts
[
1
]
apiKeyString
=
strings
.
TrimSpace
(
parts
[
1
]
)
}
}
}
}
...
@@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
...
@@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if
needsMaintenance
{
if
needsMaintenance
{
maintenanceCopy
:=
*
subscription
maintenanceCopy
:=
*
subscription
go
subscriptionService
.
DoWindowMaintenance
(
&
maintenanceCopy
)
subscriptionService
.
DoWindowMaintenance
(
&
maintenanceCopy
)
}
}
}
else
{
}
else
{
// 余额模式:检查用户余额
// 余额模式:检查用户余额
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
3fcb0cc3
...
@@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
},
},
}
}
t
.
Run
(
"standard_mode_needs_maintenance_does_not_block_request"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
cfg
.
SubscriptionMaintenance
.
WorkerCount
=
1
cfg
.
SubscriptionMaintenance
.
QueueSize
=
1
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
past
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
sub
:=
&
service
.
UserSubscription
{
ID
:
55
,
UserID
:
user
.
ID
,
GroupID
:
group
.
ID
,
Status
:
service
.
SubscriptionStatusActive
,
ExpiresAt
:
time
.
Now
()
.
Add
(
24
*
time
.
Hour
),
DailyWindowStart
:
&
past
,
DailyUsageUSD
:
0
,
}
maintenanceCalled
:=
make
(
chan
struct
{},
1
)
subscriptionRepo
:=
&
stubUserSubscriptionRepo
{
getActive
:
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
clone
:=
*
sub
return
&
clone
,
nil
},
updateStatus
:
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
return
nil
},
activateWindow
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetDaily
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
maintenanceCalled
<-
struct
{}{}
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
nil
,
cfg
)
t
.
Cleanup
(
subscriptionService
.
Stop
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
select
{
case
<-
maintenanceCalled
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatalf
(
"expected maintenance to be scheduled"
)
}
})
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
...
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
})
t
.
Run
(
"simple_mode_accepts_lowercase_bearer"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
...
...
backend/internal/server/middleware/jwt_auth.go
View file @
3fcb0cc3
...
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
...
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
// 验证Bearer scheme
// 验证Bearer scheme
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
!=
2
||
parts
[
0
]
!=
"Bearer"
{
if
len
(
parts
)
!=
2
||
!
strings
.
EqualFold
(
parts
[
0
]
,
"Bearer"
)
{
AbortWithError
(
c
,
401
,
"INVALID_AUTH_HEADER"
,
"Authorization header format must be 'Bearer {token}'"
)
AbortWithError
(
c
,
401
,
"INVALID_AUTH_HEADER"
,
"Authorization header format must be 'Bearer {token}'"
)
return
return
}
}
tokenString
:=
parts
[
1
]
tokenString
:=
strings
.
TrimSpace
(
parts
[
1
]
)
if
tokenString
==
""
{
if
tokenString
==
""
{
AbortWithError
(
c
,
401
,
"EMPTY_TOKEN"
,
"Token cannot be empty"
)
AbortWithError
(
c
,
401
,
"EMPTY_TOKEN"
,
"Token cannot be empty"
)
return
return
...
...
backend/internal/server/middleware/jwt_auth_test.go
View file @
3fcb0cc3
...
@@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) {
...
@@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) {
require
.
Equal
(
t
,
"user"
,
body
[
"role"
])
require
.
Equal
(
t
,
"user"
,
body
[
"role"
])
}
}
func
TestJWTAuth_ValidToken_LowercaseBearer
(
t
*
testing
.
T
)
{
user
:=
&
service
.
User
{
ID
:
1
,
Email
:
"test@example.com"
,
Role
:
"user"
,
Status
:
service
.
StatusActive
,
Concurrency
:
5
,
TokenVersion
:
1
,
}
router
,
authSvc
:=
newJWTTestEnv
(
map
[
int64
]
*
service
.
User
{
1
:
user
})
token
,
err
:=
authSvc
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/protected"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestJWTAuth_MissingAuthorizationHeader
(
t
*
testing
.
T
)
{
func
TestJWTAuth_MissingAuthorizationHeader
(
t
*
testing
.
T
)
{
router
,
_
:=
newJWTTestEnv
(
nil
)
router
,
_
:=
newJWTTestEnv
(
nil
)
...
...
backend/internal/server/middleware/misc_coverage_test.go
0 → 100644
View file @
3fcb0cc3
//go:build unit
package
middleware
import
(
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestClientRequestID_GeneratesWhenMissing
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
v
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
require
.
NotNil
(
t
,
v
)
id
,
ok
:=
v
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
NotEmpty
(
t
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestClientRequestID_PreservesExisting
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ClientRequestID
())
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
id
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"keep"
,
id
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
ClientRequestID
,
"keep"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestRequestBodyLimit_LimitsBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
RequestBodyLimit
(
4
))
r
.
POST
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
_
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
require
.
Error
(
t
,
err
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/t"
,
bytes
.
NewBufferString
(
"12345"
))
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestForcePlatform_SetsContextAndGinValue
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
r
.
Use
(
ForcePlatform
(
"anthropic"
))
r
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
require
.
True
(
t
,
HasForcePlatform
(
c
))
v
,
ok
:=
GetForcePlatformFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"anthropic"
,
v
)
ctxV
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ForcePlatform
)
require
.
Equal
(
t
,
"anthropic"
,
ctxV
)
c
.
Status
(
http
.
StatusOK
)
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAuthSubjectHelpers_RoundTrip
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
c
.
Set
(
string
(
ContextKeyUser
),
AuthSubject
{
UserID
:
1
,
Concurrency
:
2
})
c
.
Set
(
string
(
ContextKeyUserRole
),
"admin"
)
sub
,
ok
:=
GetAuthSubjectFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
sub
.
UserID
)
require
.
Equal
(
t
,
2
,
sub
.
Concurrency
)
role
,
ok
:=
GetUserRoleFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"admin"
,
role
)
}
func
TestAPIKeyAndSubscriptionFromContext
(
t
*
testing
.
T
)
{
c
:=
&
gin
.
Context
{}
key
:=
&
service
.
APIKey
{
ID
:
1
}
c
.
Set
(
string
(
ContextKeyAPIKey
),
key
)
gotKey
,
ok
:=
GetAPIKeyFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
1
),
gotKey
.
ID
)
sub
:=
&
service
.
UserSubscription
{
ID
:
2
}
c
.
Set
(
string
(
ContextKeySubscription
),
sub
)
gotSub
,
ok
:=
GetSubscriptionFromContext
(
c
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
int64
(
2
),
gotSub
.
ID
)
}
backend/internal/service/subscription_maintenance_queue.go
0 → 100644
View file @
3fcb0cc3
package
service
import
(
"fmt"
"log"
"sync"
)
// SubscriptionMaintenanceQueue 提供“有界队列 + 固定 worker”的后台执行器。
// 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。
type
SubscriptionMaintenanceQueue
struct
{
queue
chan
func
()
wg
sync
.
WaitGroup
stop
sync
.
Once
}
func
NewSubscriptionMaintenanceQueue
(
workerCount
,
queueSize
int
)
*
SubscriptionMaintenanceQueue
{
if
workerCount
<=
0
{
workerCount
=
1
}
if
queueSize
<=
0
{
queueSize
=
1
}
q
:=
&
SubscriptionMaintenanceQueue
{
queue
:
make
(
chan
func
(),
queueSize
),
}
q
.
wg
.
Add
(
workerCount
)
for
i
:=
0
;
i
<
workerCount
;
i
++
{
go
func
(
workerID
int
)
{
defer
q
.
wg
.
Done
()
for
fn
:=
range
q
.
queue
{
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
log
.
Printf
(
"SubscriptionMaintenance worker panic: %v"
,
r
)
}
}()
fn
()
}()
}
}(
i
)
}
return
q
}
// TryEnqueue 尝试将任务入队。
// 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。
func
(
q
*
SubscriptionMaintenanceQueue
)
TryEnqueue
(
task
func
())
error
{
if
q
==
nil
{
return
fmt
.
Errorf
(
"maintenance queue is nil"
)
}
if
task
==
nil
{
return
fmt
.
Errorf
(
"maintenance task is nil"
)
}
select
{
case
q
.
queue
<-
task
:
return
nil
default
:
return
fmt
.
Errorf
(
"maintenance queue full"
)
}
}
func
(
q
*
SubscriptionMaintenanceQueue
)
Stop
()
{
if
q
==
nil
{
return
}
q
.
stop
.
Do
(
func
()
{
close
(
q
.
queue
)
q
.
wg
.
Wait
()
})
}
backend/internal/service/subscription_maintenance_queue_test.go
0 → 100644
View file @
3fcb0cc3
//go:build unit
package
service
import
(
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func
TestSubscriptionMaintenanceQueue_TryEnqueue_QueueFull
(
t
*
testing
.
T
)
{
q
:=
NewSubscriptionMaintenanceQueue
(
1
,
1
)
t
.
Cleanup
(
q
.
Stop
)
block
:=
make
(
chan
struct
{})
var
started
atomic
.
Int32
require
.
NoError
(
t
,
q
.
TryEnqueue
(
func
()
{
started
.
Store
(
1
)
<-
block
}))
// Wait until worker started consuming the first task.
require
.
Eventually
(
t
,
func
()
bool
{
return
started
.
Load
()
==
1
},
time
.
Second
,
10
*
time
.
Millisecond
)
// Queue size is 1; with the worker blocked, enqueueing one more should fill it.
require
.
NoError
(
t
,
q
.
TryEnqueue
(
func
()
{}))
// Now the queue is full; next enqueue must fail.
err
:=
q
.
TryEnqueue
(
func
()
{})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"full"
)
close
(
block
)
}
func
TestSubscriptionMaintenanceQueue_TryEnqueue_PanicDoesNotKillWorker
(
t
*
testing
.
T
)
{
q
:=
NewSubscriptionMaintenanceQueue
(
1
,
8
)
t
.
Cleanup
(
q
.
Stop
)
require
.
NoError
(
t
,
q
.
TryEnqueue
(
func
()
{
panic
(
"boom"
)
}))
done
:=
make
(
chan
struct
{})
require
.
NoError
(
t
,
q
.
TryEnqueue
(
func
()
{
close
(
done
)
}))
select
{
case
<-
done
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatalf
(
"worker did not continue after panic"
)
}
}
backend/internal/service/subscription_service.go
View file @
3fcb0cc3
...
@@ -48,6 +48,8 @@ type SubscriptionService struct {
...
@@ -48,6 +48,8 @@ type SubscriptionService struct {
subCacheGroup
singleflight
.
Group
subCacheGroup
singleflight
.
Group
subCacheTTL
time
.
Duration
subCacheTTL
time
.
Duration
subCacheJitter
int
// 抖动百分比
subCacheJitter
int
// 抖动百分比
maintenanceQueue
*
SubscriptionMaintenanceQueue
}
}
// NewSubscriptionService 创建订阅服务
// NewSubscriptionService 创建订阅服务
...
@@ -59,9 +61,31 @@ func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscript
...
@@ -59,9 +61,31 @@ func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscript
entClient
:
entClient
,
entClient
:
entClient
,
}
}
svc
.
initSubCache
(
cfg
)
svc
.
initSubCache
(
cfg
)
svc
.
initMaintenanceQueue
(
cfg
)
return
svc
return
svc
}
}
func
(
s
*
SubscriptionService
)
initMaintenanceQueue
(
cfg
*
config
.
Config
)
{
if
cfg
==
nil
{
return
}
mc
:=
cfg
.
SubscriptionMaintenance
if
mc
.
WorkerCount
<=
0
||
mc
.
QueueSize
<=
0
{
return
}
s
.
maintenanceQueue
=
NewSubscriptionMaintenanceQueue
(
mc
.
WorkerCount
,
mc
.
QueueSize
)
}
// Stop stops the maintenance worker pool.
func
(
s
*
SubscriptionService
)
Stop
()
{
if
s
==
nil
{
return
}
if
s
.
maintenanceQueue
!=
nil
{
s
.
maintenanceQueue
.
Stop
()
}
}
// initSubCache 初始化订阅 L1 缓存
// initSubCache 初始化订阅 L1 缓存
func
(
s
*
SubscriptionService
)
initSubCache
(
cfg
*
config
.
Config
)
{
func
(
s
*
SubscriptionService
)
initSubCache
(
cfg
*
config
.
Config
)
{
if
cfg
==
nil
{
if
cfg
==
nil
{
...
@@ -720,6 +744,23 @@ func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, grou
...
@@ -720,6 +744,23 @@ func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, grou
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
func
(
s
*
SubscriptionService
)
DoWindowMaintenance
(
sub
*
UserSubscription
)
{
func
(
s
*
SubscriptionService
)
DoWindowMaintenance
(
sub
*
UserSubscription
)
{
if
s
==
nil
{
return
}
if
s
.
maintenanceQueue
!=
nil
{
err
:=
s
.
maintenanceQueue
.
TryEnqueue
(
func
()
{
s
.
doWindowMaintenance
(
sub
)
})
if
err
!=
nil
{
log
.
Printf
(
"Subscription maintenance enqueue failed: %v"
,
err
)
}
return
}
s
.
doWindowMaintenance
(
sub
)
}
func
(
s
*
SubscriptionService
)
doWindowMaintenance
(
sub
*
UserSubscription
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
defer
cancel
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment