Commit aa6f2533 authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并 upstream/main 并解决冲突

解决了以下文件的冲突:
- backend/internal/handler/admin/setting_handler.go
  - 采用 upstream 的字段对齐风格和 *Configured 字段名
  - 添加 EnableIdentityPatch 和 IdentityPatchPrompt 字段

- backend/internal/handler/gateway_handler.go
  - 采用 upstream 的 billingErrorDetails 错误处理方式

- frontend/src/api/admin/settings.ts
  - 采用 upstream 的 *_configured 字段名
  - 添加 enable_identity_patch 和 identity_patch_prompt 字段

- frontend/src/views/admin/SettingsView.vue
  - 合并 turnstile_secret_key_configured 字段
  - 保留 enable_identity_patch 和 identity_patch_prompt 字段
parents f60f943d 46dda583
...@@ -268,6 +268,15 @@ default: ...@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0 rate_multiplier: 1.0
``` ```
Additional security-related options are available in `config.yaml`:
- `cors.allowed_origins` for CORS allowlist
- `security.url_allowlist` for upstream/pricing/CRS host allowlists
- `security.csp` to control Content-Security-Policy headers
- `billing.circuit_breaker` to fail closed on billing errors
- `server.trusted_proxies` to enable X-Forwarded-For parsing
- `turnstile.required` to require Turnstile in release mode
```bash ```bash
# 6. Run the application # 6. Run the application
./sub2api ./sub2api
......
...@@ -268,6 +268,15 @@ default: ...@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0 rate_multiplier: 1.0
``` ```
`config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单
- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
- `security.csp` 配置 Content-Security-Policy
- `billing.circuit_breaker` 计费异常时 fail-closed
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile
```bash ```bash
# 6. 运行应用 # 6. 运行应用
./sub2api ./sub2api
......
...@@ -86,7 +86,8 @@ func main() { ...@@ -86,7 +86,8 @@ func main() {
func runSetupServer() { func runSetupServer() {
r := gin.New() r := gin.New()
r.Use(middleware.Recovery()) r.Use(middleware.Recovery())
r.Use(middleware.CORS()) r.Use(middleware.CORS(config.CORSConfig{}))
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
// Register setup routes // Register setup routes
setup.RegisterRoutes(r) setup.RegisterRoutes(r)
......
...@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService) dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db) accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber() proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
...@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
...@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient() pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
package config package config
import ( import (
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"log"
"strings" "strings"
"time" "time"
...@@ -14,6 +17,8 @@ const ( ...@@ -14,6 +17,8 @@ const (
RunModeSimple = "simple" RunModeSimple = "simple"
) )
const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量 // 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const ( const (
...@@ -30,6 +35,10 @@ const ( ...@@ -30,6 +35,10 @@ const (
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"` Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
...@@ -37,6 +46,7 @@ type Config struct { ...@@ -37,6 +46,7 @@ type Config struct {
RateLimit RateLimitConfig `mapstructure:"rate_limit"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"` Gateway GatewayConfig `mapstructure:"gateway"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
...@@ -95,11 +105,61 @@ type PricingConfig struct { ...@@ -95,11 +105,61 @@ type PricingConfig struct {
} }
type ServerConfig struct { type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
}
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowCredentials bool `mapstructure:"allow_credentials"`
}
type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
type URLAllowlistConfig struct {
UpstreamHosts []string `mapstructure:"upstream_hosts"`
PricingHosts []string `mapstructure:"pricing_hosts"`
CRSHosts []string `mapstructure:"crs_hosts"`
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
}
type ResponseHeaderConfig struct {
AdditionalAllowed []string `mapstructure:"additional_allowed"`
ForceRemove []string `mapstructure:"force_remove"`
}
type CSPConfig struct {
Enabled bool `mapstructure:"enabled"`
Policy string `mapstructure:"policy"`
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
}
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
}
type CircuitBreakerConfig struct {
Enabled bool `mapstructure:"enabled"`
FailureThreshold int `mapstructure:"failure_threshold"`
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
HalfOpenRequests int `mapstructure:"half_open_requests"`
}
type ConcurrencyConfig struct {
// PingInterval: 并发等待期间的 SSE ping 间隔(秒)
PingInterval int `mapstructure:"ping_interval"`
} }
// GatewayConfig API网关相关配置 // GatewayConfig API网关相关配置
...@@ -134,6 +194,13 @@ type GatewayConfig struct { ...@@ -134,6 +194,13 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize int `mapstructure:"max_line_size"`
// 是否记录上游错误响应体摘要(避免输出请求内容) // 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断) // 上游错误响应体记录最大字节数(超过会截断)
...@@ -237,6 +304,10 @@ type JWTConfig struct { ...@@ -237,6 +304,10 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"` ExpireHour int `mapstructure:"expire_hour"`
} }
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
type DefaultConfig struct { type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"` AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"` AdminPassword string `mapstructure:"admin_password"`
...@@ -287,11 +358,39 @@ func Load() (*Config, error) { ...@@ -287,11 +358,39 @@ func Load() (*Config, error) {
} }
cfg.RunMode = NormalizeRunMode(cfg.RunMode) cfg.RunMode = NormalizeRunMode(cfg.RunMode)
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug"
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
secret, err := generateJWTSecret(64)
if err != nil {
return nil, fmt.Errorf("generate jwt secret error: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
}
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err) return nil, fmt.Errorf("validate config error: %w", err)
} }
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
)
}
return &cfg, nil return &cfg, nil
} }
...@@ -304,6 +403,39 @@ func setDefaults() { ...@@ -304,6 +403,39 @@ func setDefaults() {
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
// Security
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
"api.openai.com",
"api.anthropic.com",
"generativelanguage.googleapis.com",
"cloudcode-pa.googleapis.com",
"*.openai.azure.com",
})
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
"raw.githubusercontent.com",
})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
// Turnstile
viper.SetDefault("turnstile.required", false)
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
...@@ -329,7 +461,7 @@ func setDefaults() { ...@@ -329,7 +461,7 @@ func setDefaults() {
viper.SetDefault("redis.min_idle_conns", 10) viper.SetDefault("redis.min_idle_conns", 10)
// JWT // JWT
viper.SetDefault("jwt.secret", "change-me-in-production") viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24) viper.SetDefault("jwt.expire_hour", 24)
// Default // Default
...@@ -357,7 +489,7 @@ func setDefaults() { ...@@ -357,7 +489,7 @@ func setDefaults() {
viper.SetDefault("timezone", "Asia/Shanghai") viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false) viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false) viper.SetDefault("gateway.inject_beta_for_apikey", false)
...@@ -365,19 +497,23 @@ func setDefaults() { ...@@ -365,19 +497,23 @@ func setDefaults() {
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
...@@ -396,11 +532,39 @@ func setDefaults() { ...@@ -396,11 +532,39 @@ func setDefaults() {
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
if c.JWT.Secret == "" { if c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret is required") if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required in release mode")
}
if len(c.JWT.Secret) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 characters")
}
if isWeakJWTSecret(c.JWT.Secret) {
return fmt.Errorf("jwt.secret is too weak")
}
}
if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive")
}
if c.JWT.ExpireHour > 168 {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
} }
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { if c.JWT.ExpireHour > 24 {
return fmt.Errorf("jwt.secret must be changed in production") log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
}
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
}
if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
}
if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
}
} }
if c.Database.MaxOpenConns <= 0 { if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive") return fmt.Errorf("database.max_open_conns must be positive")
...@@ -458,6 +622,9 @@ func (c *Config) Validate() error { ...@@ -458,6 +622,9 @@ func (c *Config) Validate() error {
if c.Gateway.IdleConnTimeoutSeconds <= 0 { if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
} }
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 { if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive") return fmt.Errorf("gateway.max_upstream_clients must be positive")
} }
...@@ -467,6 +634,26 @@ func (c *Config) Validate() error { ...@@ -467,6 +634,26 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
} }
if c.Gateway.StreamDataIntervalTimeout < 0 {
return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative")
}
if c.Gateway.StreamDataIntervalTimeout != 0 &&
(c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) {
return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds")
}
if c.Gateway.StreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative")
}
if c.Gateway.StreamKeepaliveInterval != 0 &&
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
} }
...@@ -482,9 +669,57 @@ func (c *Config) Validate() error { ...@@ -482,9 +669,57 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 { if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
} }
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
return nil return nil
} }
func normalizeStringSlice(values []string) []string {
if len(values) == 0 {
return values
}
normalized := make([]string, 0, len(values))
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}
func isWeakJWTSecret(secret string) bool {
lower := strings.ToLower(strings.TrimSpace(secret))
if lower == "" {
return true
}
weak := map[string]struct{}{
"change-me-in-production": {},
"changeme": {},
"secret": {},
"password": {},
"123456": {},
"12345678": {},
"admin": {},
"jwt-secret": {},
}
_, exists := weak[lower]
return exists
}
func generateJWTSecret(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 32
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable. // GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation, // This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup. // such as during setup wizard startup.
......
package admin package admin
import ( import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -34,33 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -34,33 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort, SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername, SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword, SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom, SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName, SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS, SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey, TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName, SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo, SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle, SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI, FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini, FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity, FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch, EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt, IdentityPatchPrompt: settings.IdentityPatchPrompt,
}) })
} }
...@@ -117,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -117,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数 // 验证参数
if req.DefaultConcurrency < 1 { if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1 req.DefaultConcurrency = 1
...@@ -193,6 +203,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -193,6 +203,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回 // 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
...@@ -201,36 +213,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -201,36 +213,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort, SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername, SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword, SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom, SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName, SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS, SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled, TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey, TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey, TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName, SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo, SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle, SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL, APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo, ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL, DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini, FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch, EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
}) })
} }
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host" binding:"required"`
......
...@@ -5,17 +5,17 @@ type SystemSettings struct { ...@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"` SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"` SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"` SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"` SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
......
This diff is collapsed.
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -26,8 +27,8 @@ import ( ...@@ -26,8 +27,8 @@ import (
const ( const (
// maxConcurrencyWait 等待并发槽位的最大时间 // maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔 // defaultPingInterval 流式响应等待时发送 ping 的默认间隔
pingInterval = 15 * time.Second defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间 // initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避) // backoffMultiplier 退避时间乘数(指数退避)
...@@ -44,6 +45,8 @@ const ( ...@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n" SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec) // SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = "" SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
) )
// ConcurrencyError represents a concurrency limit error with context // ConcurrencyError represents a concurrency limit error with context
...@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string { ...@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct { type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat pingFormat SSEPingFormat
pingInterval time.Duration
} }
// NewConcurrencyHelper creates a new ConcurrencyHelper // NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper { func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{ return &ConcurrencyHelper{
concurrencyService: concurrencyService, concurrencyService: concurrencyService,
pingFormat: pingFormat, pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
} }
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
} }
// IncrementWaitCount increments the wait count for a user // IncrementWaitCount increments the wait count for a user
...@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed // Only create ping ticker if ping is needed
var pingCh <-chan time.Time var pingCh <-chan time.Time
if needPing { if needPing {
pingTicker := time.NewTicker(pingInterval) pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop() defer pingTicker.Stop()
pingCh = pingTicker.C pingCh = pingTicker.C
} }
......
...@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames. // For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check // 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency) maxWait := service.CalculateMaxWait(authSubject.Concurrency)
...@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
defer userReleaseFunc() defer userReleaseFunc()
} }
// 2) billing eligibility check (after wait) // 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error()) status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return return
} }
...@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
...@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) { ...@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
} }
for k, vv := range res.Headers { for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers. // Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") { if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
continue continue
} }
for _, v := range vv { for _, v := range vv {
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler( ...@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
} }
} }
...@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
defer userReleaseFunc() defer userReleaseFunc()
} }
...@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request // Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
......
...@@ -25,13 +25,14 @@ import ( ...@@ -25,13 +25,14 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
// Transport 连接池默认配置 // Transport 连接池默认配置
const ( const (
defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
) )
// Options 定义共享 HTTP 客户端的构建参数 // Options 定义共享 HTTP 客户端的构建参数
...@@ -40,6 +41,9 @@ type Options struct { ...@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间 Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证 InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值) // 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100) MaxIdleConns int // 最大空闲连接总数(默认 100)
...@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) { ...@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err return nil, err
} }
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{ return &http.Client{
Transport: transport, Transport: rt,
Timeout: opts.Timeout, Timeout: opts.Timeout,
}, nil }, nil
} }
...@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) { ...@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
} }
func buildClientKey(opts Options) string { func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d", return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL), strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(), opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(), opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify, opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns, opts.MaxIdleConns,
opts.MaxIdleConnsPerHost, opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost, opts.MaxConnsPerHost,
) )
} }
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
...@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256", "code_challenge_method": "S256",
} }
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct { var result struct {
...@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState fullCode = authCode + "#" + responseState
} }
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil return fullCode, nil
} }
...@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
} }
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
...@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client { ...@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client return client
} }
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}
...@@ -15,7 +15,8 @@ import ( ...@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct { type claudeUsageService struct {
usageURL string usageURL string
allowPrivateHosts bool
} }
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
...@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { ...@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} client = &http.Client{Timeout: 30 * time.Second}
......
...@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { ...@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`) }`)
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage") require.NoError(s.T(), err, "FetchUsage")
...@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { ...@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope") _, _ = io.WriteString(w, "nope")
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "") _, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err) require.Error(s.T(), err)
...@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { ...@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json") _, _ = io.WriteString(w, "not-json")
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "") _, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err) require.Error(s.T(), err)
...@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { ...@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately cancel() // Cancel immediately
......
...@@ -14,18 +14,23 @@ import ( ...@@ -14,18 +14,23 @@ import (
) )
type githubReleaseClient struct { type githubReleaseClient struct {
httpClient *http.Client httpClient *http.Client
allowPrivateHosts bool
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
} }
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: sharedClient, httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
} }
} }
...@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
} }
downloadClient, err := httpclient.GetClient(httpclient.Options{ downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute, Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
}) })
if err != nil { if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute} downloadClient = &http.Client{Timeout: 10 * time.Minute}
......
...@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq) return http.DefaultTransport.RoundTrip(newReq)
} }
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() { func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir() s.tempDir = s.T().TempDir()
} }
...@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng ...@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100)) _, _ = w.Write(bytes.Repeat([]byte("a"), 100))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file1.bin") dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
...@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { ...@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
} }
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file2.bin") dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
...@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { ...@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
} }
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file3.bin") dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
...@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { ...@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "notfound.bin") dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
...@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { ...@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum")) _, _ = w.Write([]byte("sum"))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile") require.NoError(s.T(), err, "FetchChecksumFile")
...@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { ...@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) _, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200") require.Error(s.T(), err, "expected error for non-200")
...@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { ...@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
...@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { ...@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "invalid.bin") dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100) err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
...@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { ...@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content")) _, _ = w.Write([]byte("content"))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
// Use a path that cannot be created (directory doesn't exist) // Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin") dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
...@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { ...@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url") _, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL") require.Error(s.T(), err, "expected error for invalid URL")
...@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { ...@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { ...@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { ...@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
...@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { ...@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
// 默认配置常量 // 默认配置常量
...@@ -30,9 +31,9 @@ const ( ...@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接 // 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240 defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟 // defaultIdleConnTimeout: 默认空闲连接超时时间(90秒
// 超时后连接会被关闭,释放系统资源 // 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 300 * time.Second defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时 // LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second defaultResponseHeaderTimeout = 300 * time.Second
...@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { ...@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 // - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 // - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用 // 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil { if err != nil {
...@@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i ...@@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil return resp, nil
} }
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求 // acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰 // 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
...@@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a ...@@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err) return nil, fmt.Errorf("build transport: %w", err)
} }
client := &http.Client{Transport: transport} client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{ entry := &upstreamClientEntry{
client: client, client: client,
proxyKey: proxyKey, proxyKey: proxyKey,
......
...@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct { ...@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化 // SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖 // 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() { func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{} s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
} }
// newService 创建测试用的 httpUpstreamService 实例 // newService 创建测试用的 httpUpstreamService 实例
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -16,9 +17,15 @@ type pricingRemoteClient struct { ...@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
httpClient *http.Client httpClient *http.Client
} }
func NewPricingRemoteClient() service.PricingRemoteClient { func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment