Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
46dda583
Commit
46dda583
authored
Jan 04, 2026
by
shaw
Browse files
Merge PR #146: feat: 提升流式网关稳定性与安全策略强化
parents
27ed042c
f8e7255c
Changes
62
Hide whitespace changes
Inline
Side-by-side
README.md
View file @
46dda583
...
...
@@ -268,6 +268,15 @@ default:
rate_multiplier
:
1.0
```
Additional security-related options are available in
`config.yaml`
:
-
`cors.allowed_origins`
for CORS allowlist
-
`security.url_allowlist`
for upstream/pricing/CRS host allowlists
-
`security.csp`
to control Content-Security-Policy headers
-
`billing.circuit_breaker`
to fail closed on billing errors
-
`server.trusted_proxies`
to enable X-Forwarded-For parsing
-
`turnstile.required`
to require Turnstile in release mode
```
bash
# 6. Run the application
./sub2api
...
...
README_CN.md
View file @
46dda583
...
...
@@ -268,6 +268,15 @@ default:
rate_multiplier
:
1.0
```
`config.yaml`
还支持以下安全相关配置:
-
`cors.allowed_origins`
配置 CORS 白名单
-
`security.url_allowlist`
配置上游/价格数据/CRS 主机白名单
-
`security.csp`
配置 Content-Security-Policy
-
`billing.circuit_breaker`
计费异常时 fail-closed
-
`server.trusted_proxies`
启用可信代理解析 X-Forwarded-For
-
`turnstile.required`
在 release 模式强制启用 Turnstile
```
bash
# 6. 运行应用
./sub2api
...
...
backend/cmd/server/main.go
View file @
46dda583
...
...
@@ -86,7 +86,8 @@ func main() {
func
runSetupServer
()
{
r
:=
gin
.
New
()
r
.
Use
(
middleware
.
Recovery
())
r
.
Use
(
middleware
.
CORS
())
r
.
Use
(
middleware
.
CORS
(
config
.
CORSConfig
{}))
r
.
Use
(
middleware
.
SecurityHeaders
(
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
config
.
DefaultCSPPolicy
}))
// Register setup routes
setup
.
RegisterRoutes
(
r
)
...
...
backend/cmd/server/wire_gen.go
View file @
46dda583
...
...
@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler
:=
admin
.
NewDashboardHandler
(
dashboardService
)
accountRepository
:=
repository
.
NewAccountRepository
(
client
,
db
)
proxyRepository
:=
repository
.
NewProxyRepository
(
client
,
db
)
proxyExitInfoProber
:=
repository
.
NewProxyExitInfoProber
()
proxyExitInfoProber
:=
repository
.
NewProxyExitInfoProber
(
configConfig
)
adminService
:=
service
.
NewAdminService
(
userRepository
,
groupRepository
,
accountRepository
,
proxyRepository
,
apiKeyRepository
,
redeemCodeRepository
,
billingCacheService
,
proxyExitInfoProber
)
adminUserHandler
:=
admin
.
NewUserHandler
(
adminService
)
groupHandler
:=
admin
.
NewGroupHandler
(
adminService
)
...
...
@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider
:=
service
.
NewAntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
)
httpUpstream
:=
repository
.
NewHTTPUpstream
(
configConfig
)
antigravityGatewayService
:=
service
.
NewAntigravityGatewayService
(
accountRepository
,
gatewayCache
,
antigravityTokenProvider
,
rateLimitService
,
httpUpstream
,
settingService
)
accountTestService
:=
service
.
NewAccountTestService
(
accountRepository
,
geminiTokenProvider
,
antigravityGatewayService
,
httpUpstream
)
accountTestService
:=
service
.
NewAccountTestService
(
accountRepository
,
geminiTokenProvider
,
antigravityGatewayService
,
httpUpstream
,
configConfig
)
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyService
:=
service
.
Provide
ConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
)
concurrencyService
:=
service
.
New
ConcurrencyService
(
concurrencyCache
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
openAIOAuthHandler
:=
admin
.
NewOpenAIOAuthHandler
(
openAIOAuthService
,
adminService
)
...
...
@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeService
:=
service
.
NewUserAttributeService
(
userAttributeDefinitionRepository
,
userAttributeValueRepository
)
userAttributeHandler
:=
admin
.
NewUserAttributeHandler
(
userAttributeService
)
adminHandlers
:=
handler
.
ProvideAdminHandlers
(
dashboardHandler
,
adminUserHandler
,
groupHandler
,
accountHandler
,
oAuthHandler
,
openAIOAuthHandler
,
geminiOAuthHandler
,
antigravityOAuthHandler
,
proxyHandler
,
adminRedeemHandler
,
settingHandler
,
systemHandler
,
adminSubscriptionHandler
,
adminUsageHandler
,
userAttributeHandler
)
pricingRemoteClient
:=
repository
.
NewPricingRemoteClient
()
pricingRemoteClient
:=
repository
.
NewPricingRemoteClient
(
configConfig
)
pricingService
,
err
:=
service
.
ProvidePricingService
(
configConfig
,
pricingRemoteClient
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService
:=
service
.
ProvideTimingWheelService
()
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
userService
,
concurrencyService
,
billingCacheService
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
userService
,
concurrencyService
,
billingCacheService
,
configConfig
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
,
configConfig
)
handlerSettingHandler
:=
handler
.
ProvideSettingHandler
(
settingService
,
buildInfo
)
handlers
:=
handler
.
ProvideHandlers
(
authHandler
,
userHandler
,
apiKeyHandler
,
usageHandler
,
redeemHandler
,
subscriptionHandler
,
adminHandlers
,
gatewayHandler
,
openAIGatewayHandler
,
handlerSettingHandler
)
jwtAuthMiddleware
:=
middleware
.
NewJWTAuthMiddleware
(
authService
,
userService
)
...
...
backend/internal/config/config.go
View file @
46dda583
...
...
@@ -2,7 +2,10 @@
package
config
import
(
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"strings"
"time"
...
...
@@ -14,6 +17,8 @@ const (
RunModeSimple
=
"simple"
)
const
DefaultCSPPolicy
=
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const
(
...
...
@@ -30,6 +35,10 @@ const (
type
Config
struct
{
Server
ServerConfig
`mapstructure:"server"`
CORS
CORSConfig
`mapstructure:"cors"`
Security
SecurityConfig
`mapstructure:"security"`
Billing
BillingConfig
`mapstructure:"billing"`
Turnstile
TurnstileConfig
`mapstructure:"turnstile"`
Database
DatabaseConfig
`mapstructure:"database"`
Redis
RedisConfig
`mapstructure:"redis"`
JWT
JWTConfig
`mapstructure:"jwt"`
...
...
@@ -37,6 +46,7 @@ type Config struct {
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
Concurrency
ConcurrencyConfig
`mapstructure:"concurrency"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
...
...
@@ -95,11 +105,61 @@ type PricingConfig struct {
}
type
ServerConfig
struct
{
Host
string
`mapstructure:"host"`
Port
int
`mapstructure:"port"`
Mode
string
`mapstructure:"mode"`
// debug/release
ReadHeaderTimeout
int
`mapstructure:"read_header_timeout"`
// 读取请求头超时(秒)
IdleTimeout
int
`mapstructure:"idle_timeout"`
// 空闲连接超时(秒)
Host
string
`mapstructure:"host"`
Port
int
`mapstructure:"port"`
Mode
string
`mapstructure:"mode"`
// debug/release
ReadHeaderTimeout
int
`mapstructure:"read_header_timeout"`
// 读取请求头超时(秒)
IdleTimeout
int
`mapstructure:"idle_timeout"`
// 空闲连接超时(秒)
TrustedProxies
[]
string
`mapstructure:"trusted_proxies"`
// 可信代理列表(CIDR/IP)
}
type
CORSConfig
struct
{
AllowedOrigins
[]
string
`mapstructure:"allowed_origins"`
AllowCredentials
bool
`mapstructure:"allow_credentials"`
}
type
SecurityConfig
struct
{
URLAllowlist
URLAllowlistConfig
`mapstructure:"url_allowlist"`
ResponseHeaders
ResponseHeaderConfig
`mapstructure:"response_headers"`
CSP
CSPConfig
`mapstructure:"csp"`
ProxyProbe
ProxyProbeConfig
`mapstructure:"proxy_probe"`
}
type
URLAllowlistConfig
struct
{
UpstreamHosts
[]
string
`mapstructure:"upstream_hosts"`
PricingHosts
[]
string
`mapstructure:"pricing_hosts"`
CRSHosts
[]
string
`mapstructure:"crs_hosts"`
AllowPrivateHosts
bool
`mapstructure:"allow_private_hosts"`
}
type
ResponseHeaderConfig
struct
{
AdditionalAllowed
[]
string
`mapstructure:"additional_allowed"`
ForceRemove
[]
string
`mapstructure:"force_remove"`
}
type
CSPConfig
struct
{
Enabled
bool
`mapstructure:"enabled"`
Policy
string
`mapstructure:"policy"`
}
type
ProxyProbeConfig
struct
{
InsecureSkipVerify
bool
`mapstructure:"insecure_skip_verify"`
}
type
BillingConfig
struct
{
CircuitBreaker
CircuitBreakerConfig
`mapstructure:"circuit_breaker"`
}
type
CircuitBreakerConfig
struct
{
Enabled
bool
`mapstructure:"enabled"`
FailureThreshold
int
`mapstructure:"failure_threshold"`
ResetTimeoutSeconds
int
`mapstructure:"reset_timeout_seconds"`
HalfOpenRequests
int
`mapstructure:"half_open_requests"`
}
type
ConcurrencyConfig
struct
{
// PingInterval: 并发等待期间的 SSE ping 间隔(秒)
PingInterval
int
`mapstructure:"ping_interval"`
}
// GatewayConfig API网关相关配置
...
...
@@ -134,6 +194,13 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout
int
`mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
StreamKeepaliveInterval
int
`mapstructure:"stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize
int
`mapstructure:"max_line_size"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody
bool
`mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
...
...
@@ -237,6 +304,10 @@ type JWTConfig struct {
ExpireHour
int
`mapstructure:"expire_hour"`
}
type
TurnstileConfig
struct
{
Required
bool
`mapstructure:"required"`
}
type
DefaultConfig
struct
{
AdminEmail
string
`mapstructure:"admin_email"`
AdminPassword
string
`mapstructure:"admin_password"`
...
...
@@ -287,11 +358,39 @@ func Load() (*Config, error) {
}
cfg
.
RunMode
=
NormalizeRunMode
(
cfg
.
RunMode
)
cfg
.
Server
.
Mode
=
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
Server
.
Mode
))
if
cfg
.
Server
.
Mode
==
""
{
cfg
.
Server
.
Mode
=
"debug"
}
cfg
.
JWT
.
Secret
=
strings
.
TrimSpace
(
cfg
.
JWT
.
Secret
)
cfg
.
CORS
.
AllowedOrigins
=
normalizeStringSlice
(
cfg
.
CORS
.
AllowedOrigins
)
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
)
cfg
.
Security
.
CSP
.
Policy
=
strings
.
TrimSpace
(
cfg
.
Security
.
CSP
.
Policy
)
if
cfg
.
Server
.
Mode
!=
"release"
&&
cfg
.
JWT
.
Secret
==
""
{
secret
,
err
:=
generateJWTSecret
(
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate jwt secret error: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"Warning: JWT secret auto-generated for non-release mode. Do not use in production."
)
}
if
err
:=
cfg
.
Validate
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"validate config error: %w"
,
err
)
}
if
cfg
.
Server
.
Mode
!=
"release"
&&
cfg
.
JWT
.
Secret
!=
""
&&
isWeakJWTSecret
(
cfg
.
JWT
.
Secret
)
{
log
.
Println
(
"Warning: JWT secret appears weak; use a 32+ character random secret in production."
)
}
if
len
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
>
0
||
len
(
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
)
>
0
{
log
.
Printf
(
"AUDIT: response header policy configured additional_allowed=%v force_remove=%v"
,
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
,
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
,
)
}
return
&
cfg
,
nil
}
...
...
@@ -304,6 +403,39 @@ func setDefaults() {
viper
.
SetDefault
(
"server.mode"
,
"debug"
)
viper
.
SetDefault
(
"server.read_header_timeout"
,
30
)
// 30秒读取请求头
viper
.
SetDefault
(
"server.idle_timeout"
,
120
)
// 120秒空闲超时
viper
.
SetDefault
(
"server.trusted_proxies"
,
[]
string
{})
// CORS
viper
.
SetDefault
(
"cors.allowed_origins"
,
[]
string
{})
viper
.
SetDefault
(
"cors.allow_credentials"
,
true
)
// Security
viper
.
SetDefault
(
"security.url_allowlist.upstream_hosts"
,
[]
string
{
"api.openai.com"
,
"api.anthropic.com"
,
"generativelanguage.googleapis.com"
,
"cloudcode-pa.googleapis.com"
,
"*.openai.azure.com"
,
})
viper
.
SetDefault
(
"security.url_allowlist.pricing_hosts"
,
[]
string
{
"raw.githubusercontent.com"
,
})
viper
.
SetDefault
(
"security.url_allowlist.crs_hosts"
,
[]
string
{})
viper
.
SetDefault
(
"security.url_allowlist.allow_private_hosts"
,
false
)
viper
.
SetDefault
(
"security.response_headers.additional_allowed"
,
[]
string
{})
viper
.
SetDefault
(
"security.response_headers.force_remove"
,
[]
string
{})
viper
.
SetDefault
(
"security.csp.enabled"
,
true
)
viper
.
SetDefault
(
"security.csp.policy"
,
DefaultCSPPolicy
)
viper
.
SetDefault
(
"security.proxy_probe.insecure_skip_verify"
,
false
)
// Billing
viper
.
SetDefault
(
"billing.circuit_breaker.enabled"
,
true
)
viper
.
SetDefault
(
"billing.circuit_breaker.failure_threshold"
,
5
)
viper
.
SetDefault
(
"billing.circuit_breaker.reset_timeout_seconds"
,
30
)
viper
.
SetDefault
(
"billing.circuit_breaker.half_open_requests"
,
3
)
// Turnstile
viper
.
SetDefault
(
"turnstile.required"
,
false
)
// Database
viper
.
SetDefault
(
"database.host"
,
"localhost"
)
...
...
@@ -329,7 +461,7 @@ func setDefaults() {
viper
.
SetDefault
(
"redis.min_idle_conns"
,
10
)
// JWT
viper
.
SetDefault
(
"jwt.secret"
,
"
change-me-in-production
"
)
viper
.
SetDefault
(
"jwt.secret"
,
""
)
viper
.
SetDefault
(
"jwt.expire_hour"
,
24
)
// Default
...
...
@@ -357,7 +489,7 @@ func setDefaults() {
viper
.
SetDefault
(
"timezone"
,
"Asia/Shanghai"
)
// Gateway
viper
.
SetDefault
(
"gateway.response_header_timeout"
,
3
00
)
//
3
00秒(
5
分钟)等待上游响应头,LLM高负载时可能排队较久
viper
.
SetDefault
(
"gateway.response_header_timeout"
,
6
00
)
//
6
00秒(
10
分钟)等待上游响应头,LLM高负载时可能排队较久
viper
.
SetDefault
(
"gateway.log_upstream_error_body"
,
false
)
viper
.
SetDefault
(
"gateway.log_upstream_error_body_max_bytes"
,
2048
)
viper
.
SetDefault
(
"gateway.inject_beta_for_apikey"
,
false
)
...
...
@@ -365,19 +497,23 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper
.
SetDefault
(
"gateway.max_idle_conns"
,
240
)
// 最大空闲连接总数(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_idle_conns_per_host"
,
120
)
// 每主机最大空闲连接(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_conns_per_host"
,
240
)
// 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.idle_conn_timeout_seconds"
,
30
0
)
// 空闲连接超时(秒)
viper
.
SetDefault
(
"gateway.max_idle_conns"
,
240
)
// 最大空闲连接总数(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_idle_conns_per_host"
,
120
)
// 每主机最大空闲连接(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_conns_per_host"
,
240
)
// 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.idle_conn_timeout_seconds"
,
9
0
)
// 空闲连接超时(秒)
viper
.
SetDefault
(
"gateway.max_upstream_clients"
,
5000
)
viper
.
SetDefault
(
"gateway.client_idle_ttl_seconds"
,
900
)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
15
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
30
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.stream_data_interval_timeout"
,
180
)
viper
.
SetDefault
(
"gateway.stream_keepalive_interval"
,
10
)
viper
.
SetDefault
(
"gateway.max_line_size"
,
10
*
1024
*
1024
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
45
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_max_waiting"
,
100
)
viper
.
SetDefault
(
"gateway.scheduling.load_batch_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.slot_cleanup_interval"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"concurrency.ping_interval"
,
10
)
// TokenRefresh
viper
.
SetDefault
(
"token_refresh.enabled"
,
true
)
...
...
@@ -396,11 +532,39 @@ func setDefaults() {
}
func
(
c
*
Config
)
Validate
()
error
{
if
c
.
JWT
.
Secret
==
""
{
return
fmt
.
Errorf
(
"jwt.secret is required"
)
if
c
.
Server
.
Mode
==
"release"
{
if
c
.
JWT
.
Secret
==
""
{
return
fmt
.
Errorf
(
"jwt.secret is required in release mode"
)
}
if
len
(
c
.
JWT
.
Secret
)
<
32
{
return
fmt
.
Errorf
(
"jwt.secret must be at least 32 characters"
)
}
if
isWeakJWTSecret
(
c
.
JWT
.
Secret
)
{
return
fmt
.
Errorf
(
"jwt.secret is too weak"
)
}
}
if
c
.
JWT
.
ExpireHour
<=
0
{
return
fmt
.
Errorf
(
"jwt.expire_hour must be positive"
)
}
if
c
.
JWT
.
ExpireHour
>
168
{
return
fmt
.
Errorf
(
"jwt.expire_hour must be <= 168 (7 days)"
)
}
if
c
.
JWT
.
Secret
==
"change-me-in-production"
&&
c
.
Server
.
Mode
==
"release"
{
return
fmt
.
Errorf
(
"jwt.secret must be changed in production"
)
if
c
.
JWT
.
ExpireHour
>
24
{
log
.
Printf
(
"Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security."
,
c
.
JWT
.
ExpireHour
)
}
if
c
.
Security
.
CSP
.
Enabled
&&
strings
.
TrimSpace
(
c
.
Security
.
CSP
.
Policy
)
==
""
{
return
fmt
.
Errorf
(
"security.csp.policy is required when CSP is enabled"
)
}
if
c
.
Billing
.
CircuitBreaker
.
Enabled
{
if
c
.
Billing
.
CircuitBreaker
.
FailureThreshold
<=
0
{
return
fmt
.
Errorf
(
"billing.circuit_breaker.failure_threshold must be positive"
)
}
if
c
.
Billing
.
CircuitBreaker
.
ResetTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"billing.circuit_breaker.reset_timeout_seconds must be positive"
)
}
if
c
.
Billing
.
CircuitBreaker
.
HalfOpenRequests
<=
0
{
return
fmt
.
Errorf
(
"billing.circuit_breaker.half_open_requests must be positive"
)
}
}
if
c
.
Database
.
MaxOpenConns
<=
0
{
return
fmt
.
Errorf
(
"database.max_open_conns must be positive"
)
...
...
@@ -458,6 +622,9 @@ func (c *Config) Validate() error {
if
c
.
Gateway
.
IdleConnTimeoutSeconds
<=
0
{
return
fmt
.
Errorf
(
"gateway.idle_conn_timeout_seconds must be positive"
)
}
if
c
.
Gateway
.
IdleConnTimeoutSeconds
>
180
{
log
.
Printf
(
"Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse."
,
c
.
Gateway
.
IdleConnTimeoutSeconds
)
}
if
c
.
Gateway
.
MaxUpstreamClients
<=
0
{
return
fmt
.
Errorf
(
"gateway.max_upstream_clients must be positive"
)
}
...
...
@@ -467,6 +634,26 @@ func (c *Config) Validate() error {
if
c
.
Gateway
.
ConcurrencySlotTTLMinutes
<=
0
{
return
fmt
.
Errorf
(
"gateway.concurrency_slot_ttl_minutes must be positive"
)
}
if
c
.
Gateway
.
StreamDataIntervalTimeout
<
0
{
return
fmt
.
Errorf
(
"gateway.stream_data_interval_timeout must be non-negative"
)
}
if
c
.
Gateway
.
StreamDataIntervalTimeout
!=
0
&&
(
c
.
Gateway
.
StreamDataIntervalTimeout
<
30
||
c
.
Gateway
.
StreamDataIntervalTimeout
>
300
)
{
return
fmt
.
Errorf
(
"gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds"
)
}
if
c
.
Gateway
.
StreamKeepaliveInterval
<
0
{
return
fmt
.
Errorf
(
"gateway.stream_keepalive_interval must be non-negative"
)
}
if
c
.
Gateway
.
StreamKeepaliveInterval
!=
0
&&
(
c
.
Gateway
.
StreamKeepaliveInterval
<
5
||
c
.
Gateway
.
StreamKeepaliveInterval
>
30
)
{
return
fmt
.
Errorf
(
"gateway.stream_keepalive_interval must be 0 or between 5-30 seconds"
)
}
if
c
.
Gateway
.
MaxLineSize
<
0
{
return
fmt
.
Errorf
(
"gateway.max_line_size must be non-negative"
)
}
if
c
.
Gateway
.
MaxLineSize
!=
0
&&
c
.
Gateway
.
MaxLineSize
<
1024
*
1024
{
return
fmt
.
Errorf
(
"gateway.max_line_size must be at least 1MB"
)
}
if
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_max_waiting must be positive"
)
}
...
...
@@ -482,9 +669,57 @@ func (c *Config) Validate() error {
if
c
.
Gateway
.
Scheduling
.
SlotCleanupInterval
<
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.slot_cleanup_interval must be non-negative"
)
}
if
c
.
Concurrency
.
PingInterval
<
5
||
c
.
Concurrency
.
PingInterval
>
30
{
return
fmt
.
Errorf
(
"concurrency.ping_interval must be between 5-30 seconds"
)
}
return
nil
}
func
normalizeStringSlice
(
values
[]
string
)
[]
string
{
if
len
(
values
)
==
0
{
return
values
}
normalized
:=
make
([]
string
,
0
,
len
(
values
))
for
_
,
v
:=
range
values
{
trimmed
:=
strings
.
TrimSpace
(
v
)
if
trimmed
==
""
{
continue
}
normalized
=
append
(
normalized
,
trimmed
)
}
return
normalized
}
func
isWeakJWTSecret
(
secret
string
)
bool
{
lower
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
secret
))
if
lower
==
""
{
return
true
}
weak
:=
map
[
string
]
struct
{}{
"change-me-in-production"
:
{},
"changeme"
:
{},
"secret"
:
{},
"password"
:
{},
"123456"
:
{},
"12345678"
:
{},
"admin"
:
{},
"jwt-secret"
:
{},
}
_
,
exists
:=
weak
[
lower
]
return
exists
}
func
generateJWTSecret
(
byteLength
int
)
(
string
,
error
)
{
if
byteLength
<=
0
{
byteLength
=
32
}
buf
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
buf
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
buf
),
nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
...
...
backend/internal/handler/admin/setting_handler.go
View file @
46dda583
package
admin
import
(
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -34,31 +38,31 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response
.
Success
(
c
,
dto
.
SystemSettings
{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
SMTPHost
:
settings
.
SMTPHost
,
SMTPPort
:
settings
.
SMTPPort
,
SMTPUsername
:
settings
.
SMTPUsername
,
SMTPPassword
:
settings
.
SMTPPassword
,
SMTPFrom
:
settings
.
SMTPFrom
,
SMTPFromName
:
settings
.
SMTPFromName
,
SMTPUseTLS
:
settings
.
SMTPUseTLS
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
TurnstileSecretKey
:
settings
.
TurnstileSecretKey
,
SiteName
:
settings
.
SiteName
,
SiteLogo
:
settings
.
SiteLogo
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
APIBaseURL
:
settings
.
APIBaseURL
,
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultBalance
:
settings
.
DefaultBalance
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
FallbackModelAnthropic
:
settings
.
FallbackModelAnthropic
,
FallbackModelOpenAI
:
settings
.
FallbackModelOpenAI
,
FallbackModelGemini
:
settings
.
FallbackModelGemini
,
FallbackModelAntigravity
:
settings
.
FallbackModelAntigravity
,
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
SMTPHost
:
settings
.
SMTPHost
,
SMTPPort
:
settings
.
SMTPPort
,
SMTPUsername
:
settings
.
SMTPUsername
,
SMTPPassword
Configured
:
settings
.
SMTPPassword
Configured
,
SMTPFrom
:
settings
.
SMTPFrom
,
SMTPFromName
:
settings
.
SMTPFromName
,
SMTPUseTLS
:
settings
.
SMTPUseTLS
,
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
TurnstileSecretKey
Configured
:
settings
.
TurnstileSecretKey
Configured
,
SiteName
:
settings
.
SiteName
,
SiteLogo
:
settings
.
SiteLogo
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
APIBaseURL
:
settings
.
APIBaseURL
,
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultBalance
:
settings
.
DefaultBalance
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
FallbackModelAnthropic
:
settings
.
FallbackModelAnthropic
,
FallbackModelOpenAI
:
settings
.
FallbackModelOpenAI
,
FallbackModelGemini
:
settings
.
FallbackModelGemini
,
FallbackModelAntigravity
:
settings
.
FallbackModelAntigravity
,
})
}
...
...
@@ -111,6 +115,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
previousSettings
,
err
:=
h
.
settingService
.
GetAllSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// 验证参数
if
req
.
DefaultConcurrency
<
1
{
req
.
DefaultConcurrency
=
1
...
...
@@ -185,6 +195,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
h
.
auditSettingsUpdate
(
c
,
previousSettings
,
settings
,
req
)
// 重新获取设置返回
updatedSettings
,
err
:=
h
.
settingService
.
GetAllSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
...
...
@@ -193,34 +205,134 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response
.
Success
(
c
,
dto
.
SystemSettings
{
RegistrationEnabled
:
updatedSettings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
updatedSettings
.
EmailVerifyEnabled
,
SMTPHost
:
updatedSettings
.
SMTPHost
,
SMTPPort
:
updatedSettings
.
SMTPPort
,
SMTPUsername
:
updatedSettings
.
SMTPUsername
,
SMTPPassword
:
updatedSettings
.
SMTPPassword
,
SMTPFrom
:
updatedSettings
.
SMTPFrom
,
SMTPFromName
:
updatedSettings
.
SMTPFromName
,
SMTPUseTLS
:
updatedSettings
.
SMTPUseTLS
,
TurnstileEnabled
:
updatedSettings
.
TurnstileEnabled
,
TurnstileSiteKey
:
updatedSettings
.
TurnstileSiteKey
,
TurnstileSecretKey
:
updatedSettings
.
TurnstileSecretKey
,
SiteName
:
updatedSettings
.
SiteName
,
SiteLogo
:
updatedSettings
.
SiteLogo
,
SiteSubtitle
:
updatedSettings
.
SiteSubtitle
,
APIBaseURL
:
updatedSettings
.
APIBaseURL
,
ContactInfo
:
updatedSettings
.
ContactInfo
,
DocURL
:
updatedSettings
.
DocURL
,
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
FallbackModelAnthropic
:
updatedSettings
.
FallbackModelAnthropic
,
FallbackModelOpenAI
:
updatedSettings
.
FallbackModelOpenAI
,
FallbackModelGemini
:
updatedSettings
.
FallbackModelGemini
,
FallbackModelAntigravity
:
updatedSettings
.
FallbackModelAntigravity
,
RegistrationEnabled
:
updatedSettings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
updatedSettings
.
EmailVerifyEnabled
,
SMTPHost
:
updatedSettings
.
SMTPHost
,
SMTPPort
:
updatedSettings
.
SMTPPort
,
SMTPUsername
:
updatedSettings
.
SMTPUsername
,
SMTPPassword
Configured
:
updatedSettings
.
SMTPPassword
Configured
,
SMTPFrom
:
updatedSettings
.
SMTPFrom
,
SMTPFromName
:
updatedSettings
.
SMTPFromName
,
SMTPUseTLS
:
updatedSettings
.
SMTPUseTLS
,
TurnstileEnabled
:
updatedSettings
.
TurnstileEnabled
,
TurnstileSiteKey
:
updatedSettings
.
TurnstileSiteKey
,
TurnstileSecretKey
Configured
:
updatedSettings
.
TurnstileSecretKey
Configured
,
SiteName
:
updatedSettings
.
SiteName
,
SiteLogo
:
updatedSettings
.
SiteLogo
,
SiteSubtitle
:
updatedSettings
.
SiteSubtitle
,
APIBaseURL
:
updatedSettings
.
APIBaseURL
,
ContactInfo
:
updatedSettings
.
ContactInfo
,
DocURL
:
updatedSettings
.
DocURL
,
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
FallbackModelAnthropic
:
updatedSettings
.
FallbackModelAnthropic
,
FallbackModelOpenAI
:
updatedSettings
.
FallbackModelOpenAI
,
FallbackModelGemini
:
updatedSettings
.
FallbackModelGemini
,
FallbackModelAntigravity
:
updatedSettings
.
FallbackModelAntigravity
,
})
}
func
(
h
*
SettingHandler
)
auditSettingsUpdate
(
c
*
gin
.
Context
,
before
*
service
.
SystemSettings
,
after
*
service
.
SystemSettings
,
req
UpdateSettingsRequest
)
{
if
before
==
nil
||
after
==
nil
{
return
}
changed
:=
diffSettings
(
before
,
after
,
req
)
if
len
(
changed
)
==
0
{
return
}
subject
,
_
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
role
,
_
:=
middleware
.
GetUserRoleFromContext
(
c
)
log
.
Printf
(
"AUDIT: settings updated at=%s user_id=%d role=%s changed=%v"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
),
subject
.
UserID
,
role
,
changed
,
)
}
func
diffSettings
(
before
*
service
.
SystemSettings
,
after
*
service
.
SystemSettings
,
req
UpdateSettingsRequest
)
[]
string
{
changed
:=
make
([]
string
,
0
,
20
)
if
before
.
RegistrationEnabled
!=
after
.
RegistrationEnabled
{
changed
=
append
(
changed
,
"registration_enabled"
)
}
if
before
.
EmailVerifyEnabled
!=
after
.
EmailVerifyEnabled
{
changed
=
append
(
changed
,
"email_verify_enabled"
)
}
if
before
.
SMTPHost
!=
after
.
SMTPHost
{
changed
=
append
(
changed
,
"smtp_host"
)
}
if
before
.
SMTPPort
!=
after
.
SMTPPort
{
changed
=
append
(
changed
,
"smtp_port"
)
}
if
before
.
SMTPUsername
!=
after
.
SMTPUsername
{
changed
=
append
(
changed
,
"smtp_username"
)
}
if
req
.
SMTPPassword
!=
""
{
changed
=
append
(
changed
,
"smtp_password"
)
}
if
before
.
SMTPFrom
!=
after
.
SMTPFrom
{
changed
=
append
(
changed
,
"smtp_from_email"
)
}
if
before
.
SMTPFromName
!=
after
.
SMTPFromName
{
changed
=
append
(
changed
,
"smtp_from_name"
)
}
if
before
.
SMTPUseTLS
!=
after
.
SMTPUseTLS
{
changed
=
append
(
changed
,
"smtp_use_tls"
)
}
if
before
.
TurnstileEnabled
!=
after
.
TurnstileEnabled
{
changed
=
append
(
changed
,
"turnstile_enabled"
)
}
if
before
.
TurnstileSiteKey
!=
after
.
TurnstileSiteKey
{
changed
=
append
(
changed
,
"turnstile_site_key"
)
}
if
req
.
TurnstileSecretKey
!=
""
{
changed
=
append
(
changed
,
"turnstile_secret_key"
)
}
if
before
.
SiteName
!=
after
.
SiteName
{
changed
=
append
(
changed
,
"site_name"
)
}
if
before
.
SiteLogo
!=
after
.
SiteLogo
{
changed
=
append
(
changed
,
"site_logo"
)
}
if
before
.
SiteSubtitle
!=
after
.
SiteSubtitle
{
changed
=
append
(
changed
,
"site_subtitle"
)
}
if
before
.
APIBaseURL
!=
after
.
APIBaseURL
{
changed
=
append
(
changed
,
"api_base_url"
)
}
if
before
.
ContactInfo
!=
after
.
ContactInfo
{
changed
=
append
(
changed
,
"contact_info"
)
}
if
before
.
DocURL
!=
after
.
DocURL
{
changed
=
append
(
changed
,
"doc_url"
)
}
if
before
.
DefaultConcurrency
!=
after
.
DefaultConcurrency
{
changed
=
append
(
changed
,
"default_concurrency"
)
}
if
before
.
DefaultBalance
!=
after
.
DefaultBalance
{
changed
=
append
(
changed
,
"default_balance"
)
}
if
before
.
EnableModelFallback
!=
after
.
EnableModelFallback
{
changed
=
append
(
changed
,
"enable_model_fallback"
)
}
if
before
.
FallbackModelAnthropic
!=
after
.
FallbackModelAnthropic
{
changed
=
append
(
changed
,
"fallback_model_anthropic"
)
}
if
before
.
FallbackModelOpenAI
!=
after
.
FallbackModelOpenAI
{
changed
=
append
(
changed
,
"fallback_model_openai"
)
}
if
before
.
FallbackModelGemini
!=
after
.
FallbackModelGemini
{
changed
=
append
(
changed
,
"fallback_model_gemini"
)
}
if
before
.
FallbackModelAntigravity
!=
after
.
FallbackModelAntigravity
{
changed
=
append
(
changed
,
"fallback_model_antigravity"
)
}
return
changed
}
// TestSMTPRequest 测试SMTP连接请求
type
TestSMTPRequest
struct
{
SMTPHost
string
`json:"smtp_host" binding:"required"`
...
...
backend/internal/handler/dto/settings.go
View file @
46dda583
...
...
@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
SMTPHost
string
`json:"smtp_host"`
SMTPPort
int
`json:"smtp_port"`
SMTPUsername
string
`json:"smtp_username"`
SMTPPassword
string
`json:"smtp_password
,omitempty
"`
SMTPFrom
string
`json:"smtp_from_email"`
SMTPFromName
string
`json:"smtp_from_name"`
SMTPUseTLS
bool
`json:"smtp_use_tls"`
SMTPHost
string
`json:"smtp_host"`
SMTPPort
int
`json:"smtp_port"`
SMTPUsername
string
`json:"smtp_username"`
SMTPPassword
Configured
bool
`json:"smtp_password
_configured
"`
SMTPFrom
string
`json:"smtp_from_email"`
SMTPFromName
string
`json:"smtp_from_name"`
SMTPUseTLS
bool
`json:"smtp_use_tls"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKey
string
`json:"turnstile_secret_key
,omitempty
"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKey
Configured
bool
`json:"turnstile_secret_key
_configured
"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
...
...
backend/internal/handler/gateway_handler.go
View file @
46dda583
...
...
@@ -11,8 +11,10 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -38,14 +40,19 @@ func NewGatewayHandler(
userService
*
service
.
UserService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
cfg
*
config
.
Config
,
)
*
GatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
}
return
&
GatewayHandler
{
gatewayService
:
gatewayService
,
geminiCompatService
:
geminiCompatService
,
antigravityGatewayService
:
antigravityGatewayService
,
userService
:
userService
,
billingCacheService
:
billingCacheService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
),
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
}
}
...
...
@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
defer
userReleaseFunc
()
}
...
...
@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
log
.
Printf
(
"Billing eligibility check failed after wait: %v"
,
err
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusForbidden
,
"billing_error"
,
err
.
Error
(),
streamStarted
)
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusForbidden
,
"billing_error"
,
err
.
Error
())
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
errorResponse
(
c
,
status
,
code
,
message
)
return
}
...
...
@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
func
billingErrorDetails
(
err
error
)
(
status
int
,
code
,
message
string
)
{
if
errors
.
Is
(
err
,
service
.
ErrBillingServiceUnavailable
)
{
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
msg
=
"Billing service temporarily unavailable. Please retry later."
}
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
}
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
msg
=
err
.
Error
()
}
return
http
.
StatusForbidden
,
"billing_error"
,
msg
}
backend/internal/handler/gateway_helper.go
View file @
46dda583
...
...
@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -26,8 +27,8 @@ import (
const
(
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait
=
30
*
time
.
Second
//
p
ingInterval 流式响应等待时发送 ping 的间隔
p
ingInterval
=
1
5
*
time
.
Second
//
defaultP
ingInterval 流式响应等待时发送 ping 的
默认
间隔
defaultP
ingInterval
=
1
0
*
time
.
Second
// initialBackoff 初始退避时间
initialBackoff
=
100
*
time
.
Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
...
...
@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude
SSEPingFormat
=
"data: {
\"
type
\"
:
\"
ping
\"
}
\n\n
"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone
SSEPingFormat
=
""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment
SSEPingFormat
=
":
\n\n
"
)
// ConcurrencyError represents a concurrency limit error with context
...
...
@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string {
type
ConcurrencyHelper
struct
{
concurrencyService
*
service
.
ConcurrencyService
pingFormat
SSEPingFormat
pingInterval
time
.
Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func
NewConcurrencyHelper
(
concurrencyService
*
service
.
ConcurrencyService
,
pingFormat
SSEPingFormat
)
*
ConcurrencyHelper
{
func
NewConcurrencyHelper
(
concurrencyService
*
service
.
ConcurrencyService
,
pingFormat
SSEPingFormat
,
pingInterval
time
.
Duration
)
*
ConcurrencyHelper
{
if
pingInterval
<=
0
{
pingInterval
=
defaultPingInterval
}
return
&
ConcurrencyHelper
{
concurrencyService
:
concurrencyService
,
pingFormat
:
pingFormat
,
pingInterval
:
pingInterval
,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func
wrapReleaseOnDone
(
ctx
context
.
Context
,
releaseFunc
func
())
func
()
{
if
releaseFunc
==
nil
{
return
nil
}
var
once
sync
.
Once
wrapped
:=
func
()
{
once
.
Do
(
releaseFunc
)
}
go
func
()
{
<-
ctx
.
Done
()
wrapped
()
}()
return
wrapped
}
// IncrementWaitCount increments the wait count for a user
...
...
@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var
pingCh
<-
chan
time
.
Time
if
needPing
{
pingTicker
:=
time
.
NewTicker
(
pingInterval
)
pingTicker
:=
time
.
NewTicker
(
h
.
pingInterval
)
defer
pingTicker
.
Stop
()
pingCh
=
pingTicker
.
C
}
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
46dda583
...
...
@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription
,
_
:=
middleware
.
GetSubscriptionFromContext
(
c
)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency
:=
NewConcurrencyHelper
(
h
.
concurrencyHelper
.
concurrencyService
,
SSEPingFormatNone
)
geminiConcurrency
:=
NewConcurrencyHelper
(
h
.
concurrencyHelper
.
concurrencyService
,
SSEPingFormatNone
,
0
)
// 0) wait queue check
maxWait
:=
service
.
CalculateMaxWait
(
authSubject
.
Concurrency
)
...
...
@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
defer
userReleaseFunc
()
}
// 2) billing eligibility check (after wait)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
googleError
(
c
,
http
.
StatusForbidden
,
err
.
Error
())
status
,
_
,
message
:=
billingErrorDetails
(
err
)
googleError
(
c
,
status
,
message
)
return
}
...
...
@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
...
...
@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for
k
,
vv
:=
range
res
.
Headers
{
// Avoid overriding content-length and hop-by-hop headers.
if
strings
.
EqualFold
(
k
,
"Content-Length"
)
||
strings
.
EqualFold
(
k
,
"Transfer-Encoding"
)
||
strings
.
EqualFold
(
k
,
"Connection"
)
{
if
strings
.
EqualFold
(
k
,
"Content-Length"
)
||
strings
.
EqualFold
(
k
,
"Transfer-Encoding"
)
||
strings
.
EqualFold
(
k
,
"Connection"
)
||
strings
.
EqualFold
(
k
,
"Www-Authenticate"
)
{
continue
}
for
_
,
v
:=
range
vv
{
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
46dda583
...
...
@@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService
*
service
.
OpenAIGatewayService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
cfg
*
config
.
Config
,
)
*
OpenAIGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
}
return
&
OpenAIGatewayHandler
{
gatewayService
:
gatewayService
,
billingCacheService
:
billingCacheService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormat
None
),
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormat
Comment
,
pingInterval
),
}
}
...
...
@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
defer
userReleaseFunc
()
}
...
...
@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
log
.
Printf
(
"Billing eligibility check failed after wait: %v"
,
err
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusForbidden
,
"billing_error"
,
err
.
Error
(),
streamStarted
)
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
...
...
@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// Forward request
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
...
...
backend/internal/pkg/httpclient/pool.go
View file @
46dda583
...
...
@@ -25,13 +25,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// Transport 连接池默认配置
const
(
defaultMaxIdleConns
=
100
// 最大空闲连接数
defaultMaxIdleConnsPerHost
=
10
// 每个主机最大空闲连接数
defaultIdleConnTimeout
=
90
*
time
.
Second
// 空闲连接超时时间
defaultIdleConnTimeout
=
90
*
time
.
Second
// 空闲连接超时时间
(建议小于上游 LB 超时)
)
// Options 定义共享 HTTP 客户端的构建参数
...
...
@@ -40,6 +41,9 @@ type Options struct {
Timeout
time
.
Duration
// 请求总超时时间
ResponseHeaderTimeout
time
.
Duration
// 等待响应头超时时间
InsecureSkipVerify
bool
// 是否跳过 TLS 证书验证
ProxyStrict
bool
// 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP
bool
// 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts
bool
// 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns
int
// 最大空闲连接总数(默认 100)
...
...
@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return
nil
,
err
}
var
rt
http
.
RoundTripper
=
transport
if
opts
.
ValidateResolvedIP
&&
!
opts
.
AllowPrivateHosts
{
rt
=
&
validatedTransport
{
base
:
transport
}
}
return
&
http
.
Client
{
Transport
:
transpo
rt
,
Transport
:
rt
,
Timeout
:
opts
.
Timeout
,
},
nil
}
...
...
@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func
buildClientKey
(
opts
Options
)
string
{
return
fmt
.
Sprintf
(
"%s|%s|%s|%t|%d|%d|%d"
,
return
fmt
.
Sprintf
(
"%s|%s|%s|%t|%
t|%t|%t|%
d|%d|%d"
,
strings
.
TrimSpace
(
opts
.
ProxyURL
),
opts
.
Timeout
.
String
(),
opts
.
ResponseHeaderTimeout
.
String
(),
opts
.
InsecureSkipVerify
,
opts
.
ProxyStrict
,
opts
.
ValidateResolvedIP
,
opts
.
AllowPrivateHosts
,
opts
.
MaxIdleConns
,
opts
.
MaxIdleConnsPerHost
,
opts
.
MaxConnsPerHost
,
)
}
type
validatedTransport
struct
{
base
http
.
RoundTripper
}
func
(
t
*
validatedTransport
)
RoundTrip
(
req
*
http
.
Request
)
(
*
http
.
Response
,
error
)
{
if
req
!=
nil
&&
req
.
URL
!=
nil
{
host
:=
strings
.
TrimSpace
(
req
.
URL
.
Hostname
())
if
host
!=
""
{
if
err
:=
urlvalidator
.
ValidateResolvedIP
(
host
);
err
!=
nil
{
return
nil
,
err
}
}
}
return
t
.
base
.
RoundTrip
(
req
)
}
backend/internal/repository/claude_oauth_service.go
View file @
46dda583
...
...
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
...
...
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return
""
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 1 Response - Status: %d
, Body: %s
"
,
resp
.
StatusCode
,
resp
.
String
()
)
log
.
Printf
(
"[OAuth] Step 1 Response - Status: %d"
,
resp
.
StatusCode
)
if
!
resp
.
IsSuccessState
()
{
return
""
,
fmt
.
Errorf
(
"failed to get organizations: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
...
...
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method"
:
"S256"
,
}
reqBodyJSON
,
_
:=
json
.
Marshal
(
reqBody
)
log
.
Printf
(
"[OAuth] Step 2: Getting authorization code from %s"
,
authURL
)
reqBodyJSON
,
_
:=
json
.
Marshal
(
logredact
.
RedactMap
(
reqBody
))
log
.
Printf
(
"[OAuth] Step 2 Request Body: %s"
,
string
(
reqBodyJSON
))
var
result
struct
{
...
...
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return
""
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 2 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
resp
.
String
(
))
log
.
Printf
(
"[OAuth] Step 2 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
logredact
.
RedactJSON
(
resp
.
Bytes
()
))
if
!
resp
.
IsSuccessState
()
{
return
""
,
fmt
.
Errorf
(
"failed to get authorization code: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
...
...
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode
=
authCode
+
"#"
+
responseState
}
log
.
Printf
(
"[OAuth] Step 2 SUCCESS - Got authorization code
: %s..."
,
prefix
(
authCode
,
20
)
)
log
.
Printf
(
"[OAuth] Step 2 SUCCESS - Got authorization code
"
)
return
fullCode
,
nil
}
...
...
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody
[
"expires_in"
]
=
31536000
// 365 * 24 * 60 * 60 seconds
}
reqBodyJSON
,
_
:=
json
.
Marshal
(
reqBody
)
log
.
Printf
(
"[OAuth] Step 3: Exchanging code for token at %s"
,
s
.
tokenURL
)
reqBodyJSON
,
_
:=
json
.
Marshal
(
logredact
.
RedactMap
(
reqBody
))
log
.
Printf
(
"[OAuth] Step 3 Request Body: %s"
,
string
(
reqBodyJSON
))
var
tokenResp
oauth
.
TokenResponse
...
...
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return
nil
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 3 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
resp
.
String
(
))
log
.
Printf
(
"[OAuth] Step 3 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
logredact
.
RedactJSON
(
resp
.
Bytes
()
))
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"token exchange failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
...
...
@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return
client
}
func
prefix
(
s
string
,
n
int
)
string
{
if
n
<=
0
{
return
""
}
if
len
(
s
)
<=
n
{
return
s
}
return
s
[
:
n
]
}
backend/internal/repository/claude_usage_service.go
View file @
46dda583
...
...
@@ -15,7 +15,8 @@ import (
const
defaultClaudeUsageURL
=
"https://api.anthropic.com/api/oauth/usage"
type
claudeUsageService
struct
{
usageURL
string
usageURL
string
allowPrivateHosts
bool
}
func
NewClaudeUsageFetcher
()
service
.
ClaudeUsageFetcher
{
...
...
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func
(
s
*
claudeUsageService
)
FetchUsage
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
service
.
ClaudeUsageResponse
,
error
)
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
30
*
time
.
Second
,
ProxyURL
:
proxyURL
,
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
s
.
allowPrivateHosts
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
backend/internal/repository/claude_usage_service_test.go
View file @
46dda583
...
...
@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
,
allowPrivateHosts
:
true
,
}
resp
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
"://bad-proxy-url"
)
require
.
NoError
(
s
.
T
(),
err
,
"FetchUsage"
)
...
...
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_
,
_
=
io
.
WriteString
(
w
,
"nope"
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
,
allowPrivateHosts
:
true
,
}
_
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
""
)
require
.
Error
(
s
.
T
(),
err
)
...
...
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_
,
_
=
io
.
WriteString
(
w
,
"not-json"
)
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
,
allowPrivateHosts
:
true
,
}
_
,
err
:=
s
.
fetcher
.
FetchUsage
(
context
.
Background
(),
"at"
,
""
)
require
.
Error
(
s
.
T
(),
err
)
...
...
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-
r
.
Context
()
.
Done
()
}))
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
}
s
.
fetcher
=
&
claudeUsageService
{
usageURL
:
s
.
srv
.
URL
,
allowPrivateHosts
:
true
,
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// Cancel immediately
...
...
backend/internal/repository/github_release_service.go
View file @
46dda583
...
...
@@ -14,18 +14,23 @@ import (
)
type
githubReleaseClient
struct
{
httpClient
*
http
.
Client
httpClient
*
http
.
Client
allowPrivateHosts
bool
}
func
NewGitHubReleaseClient
()
service
.
GitHubReleaseClient
{
allowPrivate
:=
false
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
allowPrivate
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
}
return
&
githubReleaseClient
{
httpClient
:
sharedClient
,
httpClient
:
sharedClient
,
allowPrivateHosts
:
allowPrivate
,
}
}
...
...
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
}
downloadClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Minute
,
Timeout
:
10
*
time
.
Minute
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
c
.
allowPrivateHosts
,
})
if
err
!=
nil
{
downloadClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Minute
}
...
...
backend/internal/repository/github_release_service_test.go
View file @
46dda583
...
...
@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return
http
.
DefaultTransport
.
RoundTrip
(
newReq
)
}
func
newTestGitHubReleaseClient
()
*
githubReleaseClient
{
return
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{},
allowPrivateHosts
:
true
,
}
}
func
(
s
*
GitHubReleaseServiceSuite
)
SetupTest
()
{
s
.
tempDir
=
s
.
T
()
.
TempDir
()
}
...
...
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_
,
_
=
w
.
Write
(
bytes
.
Repeat
([]
byte
(
"a"
),
100
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file1.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
10
)
...
...
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file2.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
10
)
...
...
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"file3.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
200
)
...
...
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w
.
WriteHeader
(
http
.
StatusNotFound
)
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"notfound.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
s
.
srv
.
URL
,
dest
,
100
)
...
...
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_
,
_
=
w
.
Write
([]
byte
(
"sum"
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
body
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
s
.
srv
.
URL
)
require
.
NoError
(
s
.
T
(),
err
,
"FetchChecksumFile"
)
...
...
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
_
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
s
.
srv
.
URL
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for non-200"
)
...
...
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-
r
.
Context
()
.
Done
()
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
...
...
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestDownloadFile_InvalidURL
()
{
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"invalid.bin"
)
err
:=
s
.
client
.
DownloadFile
(
context
.
Background
(),
"://invalid-url"
,
dest
,
100
)
...
...
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_
,
_
=
w
.
Write
([]
byte
(
"content"
))
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
// Use a path that cannot be created (directory doesn't exist)
dest
:=
filepath
.
Join
(
s
.
tempDir
,
"nonexistent"
,
"subdir"
,
"file.bin"
)
...
...
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func
(
s
*
GitHubReleaseServiceSuite
)
TestFetchChecksumFile_InvalidURL
()
{
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
_
,
err
:=
s
.
client
.
FetchChecksumFile
(
context
.
Background
(),
"://invalid-url"
)
require
.
Error
(
s
.
T
(),
err
,
"expected error for invalid URL"
)
...
...
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
}
release
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
}
_
,
err
:=
s
.
client
.
FetchLatestRelease
(
context
.
Background
(),
"test/repo"
)
...
...
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient
:
&
http
.
Client
{
Transport
:
&
testTransport
{
testServerURL
:
s
.
srv
.
URL
},
},
allowPrivateHosts
:
true
,
}
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
...
...
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-
r
.
Context
()
.
Done
()
}))
client
,
ok
:=
NewGitHubReleaseClient
()
.
(
*
githubReleaseClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
newTestGitHubReleaseClient
()
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
...
...
backend/internal/repository/http_upstream.go
View file @
46dda583
...
...
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
...
...
@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost
=
240
// defaultIdleConnTimeout: 默认空闲连接超时时间(
5分钟
)
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout
=
30
0
*
time
.
Second
// defaultIdleConnTimeout: 默认空闲连接超时时间(
90秒
)
// 超时后连接会被关闭,释放系统资源
(建议小于上游 LB 超时)
defaultIdleConnTimeout
=
9
0
*
time
.
Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout
=
300
*
time
.
Second
...
...
@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func
(
s
*
httpUpstreamService
)
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
{
if
err
:=
s
.
validateRequestHost
(
req
);
err
!=
nil
{
return
nil
,
err
}
// 获取或创建对应的客户端,并标记请求占用
entry
,
err
:=
s
.
acquireClient
(
proxyURL
,
accountID
,
accountConcurrency
)
if
err
!=
nil
{
...
...
@@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return
resp
,
nil
}
func
(
s
*
httpUpstreamService
)
shouldValidateResolvedIP
()
bool
{
if
s
.
cfg
==
nil
{
return
false
}
return
!
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
}
func
(
s
*
httpUpstreamService
)
validateRequestHost
(
req
*
http
.
Request
)
error
{
if
!
s
.
shouldValidateResolvedIP
()
{
return
nil
}
if
req
==
nil
||
req
.
URL
==
nil
{
return
errors
.
New
(
"request url is nil"
)
}
host
:=
strings
.
TrimSpace
(
req
.
URL
.
Hostname
())
if
host
==
""
{
return
errors
.
New
(
"request host is empty"
)
}
if
err
:=
urlvalidator
.
ValidateResolvedIP
(
host
);
err
!=
nil
{
return
err
}
return
nil
}
func
(
s
*
httpUpstreamService
)
redirectChecker
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
10
{
return
errors
.
New
(
"stopped after 10 redirects"
)
}
return
s
.
validateRequestHost
(
req
)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func
(
s
*
httpUpstreamService
)
acquireClient
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
upstreamClientEntry
,
error
)
{
...
...
@@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return
nil
,
fmt
.
Errorf
(
"build transport: %w"
,
err
)
}
client
:=
&
http
.
Client
{
Transport
:
transport
}
if
s
.
shouldValidateResolvedIP
()
{
client
.
CheckRedirect
=
s
.
redirectChecker
}
entry
:=
&
upstreamClientEntry
{
client
:
client
,
proxyKey
:
proxyKey
,
...
...
backend/internal/repository/http_upstream_test.go
View file @
46dda583
...
...
@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func
(
s
*
HTTPUpstreamSuite
)
SetupTest
()
{
s
.
cfg
=
&
config
.
Config
{}
s
.
cfg
=
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
AllowPrivateHosts
:
true
,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例
...
...
backend/internal/repository/pricing_service.go
View file @
46dda583
...
...
@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
...
...
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
httpClient
*
http
.
Client
}
func
NewPricingRemoteClient
()
service
.
PricingRemoteClient
{
func
NewPricingRemoteClient
(
cfg
*
config
.
Config
)
service
.
PricingRemoteClient
{
allowPrivate
:=
false
if
cfg
!=
nil
{
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
}
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
allowPrivate
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment