"...internal/pkg/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "ece0606feddd1dbfdcd779cbf2dcaffba17a35c4"
Commit fd43be8d authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突



- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parents 792bef61 836ba14b
# 确保所有 SQL 迁移文件使用 LF 换行符
backend/migrations/*.sql text eol=lf
# Go 源代码文件
*.go text eol=lf
# Shell 脚本
*.sh text eol=lf
# YAML/YML 配置文件
*.yaml text eol=lf
*.yml text eol=lf
# Dockerfile
Dockerfile text eol=lf
...@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
...@@ -128,7 +128,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -128,7 +128,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
...@@ -144,8 +146,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -144,8 +146,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService) promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db) opsRepository := repository.NewOpsRepository(db)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
...@@ -159,7 +159,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -159,7 +159,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService) opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient) updateCache := repository.NewUpdateCache(redisClient)
......
...@@ -38,32 +38,33 @@ const ( ...@@ -38,32 +38,33 @@ const (
) )
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"` CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"` Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"` Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"` Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"` Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"` Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"` Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"` Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"` Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
Sora SoraConfig `mapstructure:"sora"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Sora SoraConfig `mapstructure:"sora"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Gemini GeminiConfig `mapstructure:"gemini"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Update UpdateConfig `mapstructure:"update"` Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
} }
type GeminiConfig struct { type GeminiConfig struct {
...@@ -148,6 +149,7 @@ type ServerConfig struct { ...@@ -148,6 +149,7 @@ type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
...@@ -267,6 +269,9 @@ type GatewayConfig struct { ...@@ -267,6 +269,9 @@ type GatewayConfig struct {
MaxBodySize int64 `mapstructure:"max_body_size"` MaxBodySize int64 `mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优) // HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数 // MaxIdleConns: 所有主机的最大空闲连接总数
...@@ -590,6 +595,13 @@ type APIKeyAuthCacheConfig struct { ...@@ -590,6 +595,13 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"` Singleflight bool `mapstructure:"singleflight"`
} }
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// DashboardCacheConfig 仪表盘统计缓存配置 // DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct { type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存 // Enabled: 是否启用仪表盘缓存
...@@ -695,6 +707,7 @@ func Load() (*Config, error) { ...@@ -695,6 +707,7 @@ func Load() (*Config, error) {
if cfg.Server.Mode == "" { if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug" cfg.Server.Mode = "debug"
} }
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
...@@ -767,7 +780,8 @@ func setDefaults() { ...@@ -767,7 +780,8 @@ func setDefaults() {
// Server // Server
viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{}) viper.SetDefault("server.trusted_proxies", []string{})
...@@ -802,7 +816,7 @@ func setDefaults() { ...@@ -802,7 +816,7 @@ func setDefaults() {
viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false) viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.enabled", true)
...@@ -840,9 +854,9 @@ func setDefaults() { ...@@ -840,9 +854,9 @@ func setDefaults() {
viper.SetDefault("database.user", "postgres") viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres") viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 50) viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 10) viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5) viper.SetDefault("database.conn_max_idle_time_minutes", 5)
...@@ -854,8 +868,8 @@ func setDefaults() { ...@@ -854,8 +868,8 @@ func setDefaults() {
viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128) viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 10) viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false) viper.SetDefault("redis.enable_tls", false)
// Ops (vNext) // Ops (vNext)
...@@ -914,6 +928,11 @@ func setDefaults() { ...@@ -914,6 +928,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true) viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache // Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
...@@ -947,6 +966,7 @@ func setDefaults() { ...@@ -947,6 +966,7 @@ func setDefaults() {
viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
...@@ -958,9 +978,9 @@ func setDefaults() { ...@@ -958,9 +978,9 @@ func setDefaults() {
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认 viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认 viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
...@@ -1030,6 +1050,22 @@ func setDefaults() { ...@@ -1030,6 +1050,22 @@ func setDefaults() {
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
if strings.TrimSpace(c.Server.FrontendURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
if err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
if u.RawQuery != "" || u.ForceQuery {
return fmt.Errorf("server.frontend_url invalid: must not include query")
}
if u.User != nil {
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
}
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
}
if c.JWT.ExpireHour <= 0 { if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive") return fmt.Errorf("jwt.expire_hour must be positive")
} }
......
...@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { ...@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if !cfg.Security.URLAllowlist.AllowPrivateHosts { if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
} }
if cfg.Security.ResponseHeaders.Enabled { if !cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false") t.Fatalf("ResponseHeaders.Enabled = false, want true")
}
}
func TestLoadDefaultServerMode(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Mode != "release" {
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
}
}
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Database.SSLMode != "prefer" {
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
} }
} }
...@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { ...@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
} }
} }
func TestValidateServerFrontendURL(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com/path"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com?utm=1"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with query")
}
cfg.Server.FrontendURL = "https://user:pass@example.com"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
}
cfg.Server.FrontendURL = "/relative"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject relative server.frontend_url")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) { func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
......
...@@ -65,3 +65,38 @@ const ( ...@@ -65,3 +65,38 @@ const (
SubscriptionStatusExpired = "expired" SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended" SubscriptionStatusSuspended = "suspended"
) )
// DefaultAntigravityModelMapping 是 Antigravity 平台的默认模型映射
// 当账号未配置 model_mapping 时使用此默认值
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
var DefaultAntigravityModelMapping = map[string]string{
// Claude 白名单
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
// Claude 详细版本 ID 映射
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
// Claude Haiku → Sonnet(无 Haiku 支持)
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
// Gemini 3 preview 映射
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
package admin
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
dataType = "sub2api-data"
legacyDataType = "sub2api-bundle"
dataVersion = 1
dataPageCap = 1000
)
type DataPayload struct {
Type string `json:"type,omitempty"`
Version int `json:"version,omitempty"`
ExportedAt string `json:"exported_at"`
Proxies []DataProxy `json:"proxies"`
Accounts []DataAccount `json:"accounts"`
}
type DataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Status string `json:"status"`
}
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra,omitempty"`
ProxyKey *string `json:"proxy_key,omitempty"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"`
}
type DataImportRequest struct {
Data DataPayload `json:"data"`
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
}
type DataImportResult struct {
ProxyCreated int `json:"proxy_created"`
ProxyReused int `json:"proxy_reused"`
ProxyFailed int `json:"proxy_failed"`
AccountCreated int `json:"account_created"`
AccountFailed int `json:"account_failed"`
Errors []DataImportError `json:"errors,omitempty"`
}
type DataImportError struct {
Kind string `json:"kind"`
Name string `json:"name,omitempty"`
ProxyKey string `json:"proxy_key,omitempty"`
Message string `json:"message"`
}
func buildProxyKey(protocol, host string, port int, username, password string) string {
return fmt.Sprintf("%s|%s|%d|%s|%s", strings.TrimSpace(protocol), strings.TrimSpace(host), port, strings.TrimSpace(username), strings.TrimSpace(password))
}
func (h *AccountHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseAccountIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
accounts, err := h.resolveExportAccounts(ctx, selectedIDs, c)
if err != nil {
response.ErrorFrom(c, err)
return
}
includeProxies, err := parseIncludeProxies(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if includeProxies {
proxies, err = h.resolveExportProxies(ctx, accounts)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
proxies = []service.Proxy{}
}
proxyKeyByID := make(map[int64]string, len(proxies))
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyByID[p.ID] = key
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
dataAccounts := make([]DataAccount, 0, len(accounts))
for i := range accounts {
acc := accounts[i]
var proxyKey *string
if acc.ProxyID != nil {
if key, ok := proxyKeyByID[*acc.ProxyID]; ok {
proxyKey = &key
}
}
var expiresAt *int64
if acc.ExpiresAt != nil {
v := acc.ExpiresAt.Unix()
expiresAt = &v
}
dataAccounts = append(dataAccounts, DataAccount{
Name: acc.Name,
Notes: acc.Notes,
Platform: acc.Platform,
Type: acc.Type,
Credentials: acc.Credentials,
Extra: acc.Extra,
ProxyKey: proxyKey,
Concurrency: acc.Concurrency,
Priority: acc.Priority,
RateMultiplier: acc.RateMultiplier,
ExpiresAt: expiresAt,
AutoPauseOnExpired: &acc.AutoPauseOnExpired,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: dataAccounts,
}
response.Success(c, payload)
}
func (h *AccountHandler) ImportData(c *gin.Context) {
var req DataImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
response.BadRequest(c, err.Error())
return
}
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyToID[key] = p.ID
}
for i := range dataPayload.Proxies {
item := dataPayload.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existingID, ok := proxyKeyToID[key]; ok {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
proxyKeyToID[key] = created.ID
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
for i := range dataPayload.Accounts {
item := dataPayload.Accounts[i]
if err := validateDataAccount(item); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
var proxyID *int64
if item.ProxyKey != nil && *item.ProxyKey != "" {
if id, ok := proxyKeyToID[*item.ProxyKey]; ok {
proxyID = &id
} else {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
ProxyKey: *item.ProxyKey,
Message: "proxy_key not found",
})
continue
}
}
accountInput := &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: proxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: nil,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
result.AccountCreated++
}
response.Success(c, result)
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
page := 1
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, c *gin.Context) ([]service.Account, error) {
if len(ids) > 0 {
accounts, err := h.adminService.GetAccountsByIDs(ctx, ids)
if err != nil {
return nil, err
}
out := make([]service.Account, 0, len(accounts))
for _, acc := range accounts {
if acc == nil {
continue
}
out = append(out, *acc)
}
return out, nil
}
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
}
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
if len(accounts) == 0 {
return []service.Proxy{}, nil
}
seen := make(map[int64]struct{})
ids := make([]int64, 0)
for i := range accounts {
if accounts[i].ProxyID == nil {
continue
}
id := *accounts[i].ProxyID
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseAccountIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid account id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func parseIncludeProxies(c *gin.Context) (bool, error) {
raw := strings.TrimSpace(strings.ToLower(c.Query("include_proxies")))
if raw == "" {
return true, nil
}
switch raw {
case "1", "true", "yes", "on":
return true, nil
case "0", "false", "no", "off":
return false, nil
default:
return true, fmt.Errorf("invalid include_proxies value: %s", raw)
}
}
func validateDataHeader(payload DataPayload) error {
if payload.Type != "" && payload.Type != dataType && payload.Type != legacyDataType {
return fmt.Errorf("unsupported data type: %s", payload.Type)
}
if payload.Version != 0 && payload.Version != dataVersion {
return fmt.Errorf("unsupported data version: %d", payload.Version)
}
if payload.Proxies == nil {
return errors.New("proxies is required")
}
if payload.Accounts == nil {
return errors.New("accounts is required")
}
return nil
}
func validateDataProxy(item DataProxy) error {
if strings.TrimSpace(item.Protocol) == "" {
return errors.New("proxy protocol is required")
}
if strings.TrimSpace(item.Host) == "" {
return errors.New("proxy host is required")
}
if item.Port <= 0 || item.Port > 65535 {
return errors.New("proxy port is invalid")
}
switch item.Protocol {
case "http", "https", "socks5", "socks5h":
default:
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
}
if item.Status != "" {
normalizedStatus := normalizeProxyStatus(item.Status)
if normalizedStatus != service.StatusActive && normalizedStatus != "inactive" {
return fmt.Errorf("proxy status is invalid: %s", item.Status)
}
}
return nil
}
func validateDataAccount(item DataAccount) error {
if strings.TrimSpace(item.Name) == "" {
return errors.New("account name is required")
}
if strings.TrimSpace(item.Platform) == "" {
return errors.New("account platform is required")
}
if strings.TrimSpace(item.Type) == "" {
return errors.New("account type is required")
}
if len(item.Credentials) == 0 {
return errors.New("account credentials is required")
}
switch item.Type {
case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream:
default:
return fmt.Errorf("account type is invalid: %s", item.Type)
}
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
return errors.New("rate_multiplier must be >= 0")
}
if item.Concurrency < 0 {
return errors.New("concurrency must be >= 0")
}
if item.Priority < 0 {
return errors.New("priority must be >= 0")
}
return nil
}
func defaultProxyName(name string) string {
if strings.TrimSpace(name) == "" {
return "imported-proxy"
}
return name
}
func normalizeProxyStatus(status string) string {
normalized := strings.TrimSpace(strings.ToLower(status))
switch normalized {
case "":
return ""
case service.StatusActive:
return service.StatusActive
case "inactive", service.StatusDisabled:
return "inactive"
default:
return normalized
}
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dataResponse struct {
Code int `json:"code"`
Data dataPayload `json:"data"`
}
type dataPayload struct {
Type string `json:"type"`
Version int `json:"version"`
Proxies []dataProxy `json:"proxies"`
Accounts []dataAccount `json:"accounts"`
}
type dataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status"`
}
type dataAccount struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyKey *string `json:"proxy_key"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
}
func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
router.POST("/api/v1/admin/accounts/data", h.ImportData)
return router, adminSvc
}
func TestExportDataIncludesSecrets(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 12,
Name: "orphan",
Protocol: "https",
Host: "10.0.0.1",
Port: 443,
Username: "o",
Password: "p",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
Extra: map[string]any{"note": "x"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "pass", resp.Data.Proxies[0].Password)
require.Len(t, resp.Data.Accounts, 1)
require.Equal(t, "secret", resp.Data.Accounts[0].Credentials["token"])
}
func TestExportDataWithoutProxies(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data?include_proxies=false", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 0)
require.Len(t, resp.Data.Accounts, 1)
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
}
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy",
Protocol: "socks5",
Host: "1.2.3.4",
Port: 1080,
Username: "u",
Password: "p",
Status: service.StatusActive,
},
}
dataPayload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"name": "proxy",
"protocol": "socks5",
"host": "1.2.3.4",
"port": 1080,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{
{
"name": "acc",
"platform": service.PlatformOpenAI,
"type": service.AccountTypeOAuth,
"credentials": map[string]any{"token": "x"},
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"concurrency": 3,
"priority": 50,
},
},
},
"skip_default_group_bind": true,
}
body, _ := json.Marshal(dataPayload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdProxies, 0)
require.Len(t, adminSvc.createdAccounts, 1)
require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind)
}
...@@ -3,11 +3,13 @@ package admin ...@@ -3,11 +3,13 @@ package admin
import ( import (
"errors" "errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
...@@ -696,11 +698,61 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { ...@@ -696,11 +698,61 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return return
} }
// Return mock data for now ctx := c.Request.Context()
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{ response.Success(c, gin.H{
"success": len(req.Accounts), "success": success,
"failed": 0, "failed": failed,
"results": []gin.H{}, "results": results,
}) })
} }
...@@ -738,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { ...@@ -738,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
} }
ctx := c.Request.Context() ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs { for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID) account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil { if err != nil {
failed++ response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
results = append(results, gin.H{ return
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
} }
// Update credentials field
if account.Credentials == nil { if account.Credentials == nil {
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials[req.Field] = req.Value account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account // 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
for _, u := range updates {
updateInput := &service.UpdateAccountInput{ updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials, Credentials: u.Credentials,
} }
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
if err != nil { return
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
} }
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"success": success, "success": len(updates),
"failed": failed, "failed": 0,
"results": results,
}) })
} }
...@@ -1440,3 +1475,9 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -1440,3 +1475,9 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
response.Success(c, results) response.Success(c, results)
} }
// GetAntigravityDefaultModelMapping 获取 Antigravity 平台的默认模型映射
// GET /api/v1/admin/accounts/antigravity/default-model-mapping
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
...@@ -2,19 +2,27 @@ package admin ...@@ -2,19 +2,27 @@ package admin
import ( import (
"context" "context"
"strings"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
type stubAdminService struct { type stubAdminService struct {
users []service.User users []service.User
apiKeys []service.APIKey apiKeys []service.APIKey
groups []service.Group groups []service.Group
accounts []service.Account accounts []service.Account
proxies []service.Proxy proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
} }
func newStubAdminService() *stubAdminService { func newStubAdminService() *stubAdminService {
...@@ -177,6 +185,9 @@ func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([ ...@@ -177,6 +185,9 @@ func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([
} }
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) { func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil return &account, nil
} }
...@@ -214,7 +225,25 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic ...@@ -214,7 +225,25 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
} }
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies {
if protocol != "" && proxy.Protocol != protocol {
continue
}
if status != "" && proxy.Status != status {
continue
}
if search != "" {
name := strings.ToLower(proxy.Name)
host := strings.ToLower(proxy.Host)
if !strings.Contains(name, search) && !strings.Contains(host, search) {
continue
}
}
filtered = append(filtered, proxy)
}
return filtered, int64(len(filtered)), nil
} }
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
...@@ -230,16 +259,47 @@ func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([ ...@@ -230,16 +259,47 @@ func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([
} }
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) { func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
for i := range s.proxies {
proxy := s.proxies[i]
if proxy.ID == id {
return &proxy, nil
}
}
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive} proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil return &proxy, nil
} }
func (s *stubAdminService) GetProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
out := make([]service.Proxy, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
seen[id] = struct{}{}
}
for i := range s.proxies {
proxy := s.proxies[i]
if _, ok := seen[proxy.ID]; ok {
out = append(out, proxy)
}
}
return out, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) { func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.createdProxies = append(s.createdProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive} proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil return &proxy, nil
} }
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) { func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.updatedProxyIDs = append(s.updatedProxyIDs, id)
s.updatedProxies = append(s.updatedProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive} proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil return &proxy, nil
} }
...@@ -261,6 +321,9 @@ func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, po ...@@ -261,6 +321,9 @@ func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, po
} }
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) { func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
s.mu.Lock()
s.testedProxyIDs = append(s.testedProxyIDs, id)
s.mu.Unlock()
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
} }
......
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
// 让第 2 个账号(ID=2)更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
require.Equal(t, int64(2), svc.updateCallCount.Load(),
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型(number),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null(设置为空),应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}
...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { ...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage stats") response.Error(c, 500, "Failed to get user usage stats")
return return
...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { ...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage stats") response.Error(c, 500, "Failed to get API key usage stats")
return return
......
...@@ -63,6 +63,43 @@ func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) { ...@@ -63,6 +63,43 @@ func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) {
response.Success(c, payload) response.Success(c, payload)
} }
// GetUserConcurrencyStats returns real-time concurrency usage for all active users.
// GET /api/v1/admin/ops/user-concurrency
func (h *OpsHandler) GetUserConcurrencyStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"user": map[int64]*service.UserConcurrencyInfo{},
"timestamp": time.Now().UTC(),
})
return
}
users, collectedAt, err := h.opsService.GetUserConcurrencyStats(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"user": users,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
// GetAccountAvailability returns account availability statistics. // GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability // GET /api/v1/admin/ops/account-availability
// //
......
package admin
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ExportData exports proxy-only data for migration.
func (h *ProxyHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseProxyIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if len(selectedIDs) > 0 {
proxies, err = h.getProxiesByIDs(ctx, selectedIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
protocol := c.Query("protocol")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: []DataAccount{},
}
response.Success(c, payload)
}
// ImportData imports proxy-only data for migration.
func (h *ProxyHandler) ImportData(c *gin.Context) {
type ProxyImportRequest struct {
Data DataPayload `json:"data"`
}
var req ProxyImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
ctx := c.Request.Context()
result := DataImportResult{}
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
if err != nil {
response.ErrorFrom(c, err)
return
}
proxyByKey := make(map[string]service.Proxy, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyByKey[key] = p
}
latencyProbeIDs := make([]int64, 0, len(req.Data.Proxies))
for i := range req.Data.Proxies {
item := req.Data.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existing, ok := proxyByKey[key]; ok {
result.ProxyReused++
if normalizedStatus != "" && normalizedStatus != existing.Status {
if _, err := h.adminService.UpdateProxy(ctx, existing.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
latencyProbeIDs = append(latencyProbeIDs, existing.ID)
continue
}
created, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
result.ProxyCreated++
proxyByKey[key] = *created
if normalizedStatus != "" && normalizedStatus != created.Status {
if _, err := h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
// CreateProxy already triggers a latency probe, avoid double probing here.
}
if len(latencyProbeIDs) > 0 {
ids := append([]int64(nil), latencyProbeIDs...)
go func() {
for _, id := range ids {
_, _ = h.adminService.TestProxy(context.Background(), id)
}
}()
}
response.Success(c, result)
}
func (h *ProxyHandler) getProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseProxyIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid proxy id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type proxyDataResponse struct {
Code int `json:"code"`
Data DataPayload `json:"data"`
}
type proxyImportResponse struct {
Code int `json:"code"`
Data DataImportResult `json:"data"`
}
func setupProxyDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewProxyHandler(adminSvc)
router.GET("/api/v1/admin/proxies/data", h.ExportData)
router.POST("/api/v1/admin/proxies/data", h.ImportData)
return router, adminSvc
}
func TestProxyExportDataRespectsFilters(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=https", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
}
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?ids=2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
}
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
payload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "http|127.0.0.1|8080|user|pass",
"name": "proxy-a",
"protocol": "http",
"host": "127.0.0.1",
"port": 8080,
"username": "user",
"password": "pass",
"status": "inactive",
},
{
"proxy_key": "https|10.0.0.2|443|u|p",
"name": "proxy-b",
"protocol": "https",
"host": "10.0.0.2",
"port": 443,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{},
},
}
body, _ := json.Marshal(payload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyImportResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, 1, resp.Data.ProxyCreated)
require.Equal(t, 1, resp.Data.ProxyReused)
require.Equal(t, 0, resp.Data.ProxyFailed)
adminSvc.mu.Lock()
updatedIDs := append([]int64(nil), adminSvc.updatedProxyIDs...)
adminSvc.mu.Unlock()
require.Contains(t, updatedIDs, int64(1))
require.Eventually(t, func() bool {
adminSvc.mu.Lock()
defer adminSvc.mu.Unlock()
return len(adminSvc.testedProxyIDs) == 1
}, time.Second, 10*time.Millisecond)
}
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}
...@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) { ...@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if runes := []rune(search); len(runes) > 100 {
search = search[:100] search = string(runes[:100])
} }
filters := service.UserListFilters{ filters := service.UserListFilters{
......
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"log/slog" "log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { ...@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return return
} }
// Build frontend base URL from request frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
scheme := "https" if frontendBaseURL == "" {
if c.Request.TLS == nil { slog.Error("server.frontend_url not configured; cannot build password reset link")
// Check X-Forwarded-Proto header (common in reverse proxy setups) response.InternalError(c, "Password reset is not configured")
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { return
scheme = proto
} else {
scheme = "http"
}
} }
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async) // Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration) // Note: This returns success even if email doesn't exist (to prevent enumeration)
......
...@@ -215,17 +215,6 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -215,17 +215,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
} }
} }
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
now := time.Now()
for scope, remainingSec := range scopeLimits {
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
RemainingSec: remainingSec,
}
}
}
return out return out
} }
......
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"context" "context"
"crypto/rand"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -113,9 +114,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -113,9 +114,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)
...@@ -126,6 +124,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -126,6 +124,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model reqModel := parsedReq.Model
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
c.Request = c.Request.WithContext(ctx)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
// 验证 model 必填 // 验证 model 必填
...@@ -137,6 +149,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -137,6 +149,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Track if we've started streaming (for error handling) // Track if we've started streaming (for error handling)
streamStarted := false streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// 获取订阅信息(可能为nil)- 提前获取用于后续检查 // 获取订阅信息(可能为nil)- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
...@@ -202,17 +219,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -202,17 +219,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
// 查询粘性会话绑定的账号 ID
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -227,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -227,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截(预热请求、SUGGESTION MODE等) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() { if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body) interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone { if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil { if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc() selection.ReleaseFunc()
...@@ -260,12 +287,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -260,12 +287,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
// Ensure the wait counter is decremented if we exit before acquiring the slot. releaseWait := func() {
defer func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
} }
}() }
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -277,14 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -277,14 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
// Slot acquired: no longer waiting in queue. // Slot acquired: no longer waiting in queue.
if accountWaitCounted { releaseWait()
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
...@@ -299,7 +324,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -299,7 +324,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
} }
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else { } else {
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
} }
...@@ -311,6 +336,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -311,6 +336,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return return
...@@ -329,22 +357,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -329,22 +357,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent, clientIP) }(result, account, userAgent, clientIP, forceCacheBilling)
return return
} }
} }
...@@ -363,13 +392,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -363,13 +392,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -384,7 +415,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -384,7 +415,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截(预热请求、SUGGESTION MODE等) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() { if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body) interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone { if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil { if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc() selection.ReleaseFunc()
...@@ -417,11 +448,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -417,11 +448,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
defer func() { releaseWait := func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
} }
}() }
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -433,13 +465,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -433,13 +465,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted { // Slot acquired: no longer waiting in queue.
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) releaseWait()
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
...@@ -454,7 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -454,7 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
} }
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body) result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else { } else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
} }
...@@ -501,6 +532,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -501,6 +532,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return return
...@@ -519,22 +553,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -519,22 +553,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: currentSubscription, Subscription: currentSubscription,
UserAgent: ua, UserAgent: ua,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent, clientIP) }(result, account, userAgent, clientIP, forceCacheBilling)
return return
} }
if !retryWithFallback { if !retryWithFallback {
...@@ -917,6 +952,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -917,6 +952,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
// 验证 model 必填 // 验证 model 必填
if parsedReq.Model == "" { if parsedReq.Model == "" {
...@@ -943,7 +980,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -943,7 +980,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return return
} }
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
...@@ -960,13 +998,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -960,13 +998,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
type InterceptType int type InterceptType int
const ( const (
InterceptTypeNone InterceptType = iota InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation") InterceptTypeWarmup // 预热请求(返回 "New Conversation")
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#")
) )
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
func isHaikuModel(model string) bool {
return strings.Contains(strings.ToLower(model), "haiku")
}
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
// 这类请求用于 Claude Code 验证 API 连通性
// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
return maxTokens == 1 && isHaikuModel(model) && !isStream
}
// detectInterceptType 检测请求是否需要拦截,返回拦截类型 // detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType { // 参数说明:
// - body: 请求体字节
// - model: 请求的模型名称
// - maxTokens: max_tokens 值
// - isStream: 是否为流式请求
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
return InterceptTypeMaxTokensOneHaiku
}
// 快速检查:如果不包含任何关键字,直接返回 // 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body) bodyStr := string(body)
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:") hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
...@@ -1116,9 +1178,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce ...@@ -1116,9 +1178,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
} }
} }
// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式)
// 格式与 Claude API 真实响应一致,24 位随机字母数字
func generateRealisticMsgID() string {
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const idLen = 24
randomBytes := make([]byte, idLen)
if _, err := rand.Read(randomBytes); err != nil {
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
}
b := make([]byte, idLen)
for i := range b {
b[i] = charset[int(randomBytes[i])%len(charset)]
}
return "msg_bdrk_" + string(b)
}
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截) // sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) { func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string var msgID, text, stopReason string
var outputTokens int var outputTokens int
switch interceptType { switch interceptType {
...@@ -1126,24 +1204,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter ...@@ -1126,24 +1204,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
msgID = "msg_mock_suggestion" msgID = "msg_mock_suggestion"
text = "" text = ""
outputTokens = 1 outputTokens = 1
stopReason = "end_turn"
case InterceptTypeMaxTokensOneHaiku:
msgID = generateRealisticMsgID()
text = "#"
outputTokens = 1
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
default: // InterceptTypeWarmup default: // InterceptTypeWarmup
msgID = "msg_mock_warmup" msgID = "msg_mock_warmup"
text = "New Conversation" text = "New Conversation"
outputTokens = 2 outputTokens = 2
stopReason = "end_turn"
} }
c.JSON(http.StatusOK, gin.H{ // 构建完整的响应格式(与 Claude API 响应格式一致)
"id": msgID, response := gin.H{
"type": "message", "model": model,
"role": "assistant", "id": msgID,
"model": model, "type": "message",
"content": []gin.H{{"type": "text", "text": text}}, "role": "assistant",
"stop_reason": "end_turn", "content": []gin.H{{"type": "text", "text": text}},
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": gin.H{ "usage": gin.H{
"input_tokens": 10, "input_tokens": 10,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
"cache_creation": gin.H{
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"output_tokens": outputTokens, "output_tokens": outputTokens,
"total_tokens": 10 + outputTokens,
}, },
}) }
c.JSON(http.StatusOK, response)
} }
func billingErrorDetails(err error) (status int, code, message string) { func billingErrorDetails(err error) (status int, code, message string) {
...@@ -1156,7 +1252,8 @@ func billingErrorDetails(err error) (status int, code, message string) { ...@@ -1156,7 +1252,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
} }
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
msg = err.Error() log.Printf("[Gateway] billing error details: %v", err)
msg = "Billing error"
} }
return http.StatusForbidden, "billing_error", msg return http.StatusForbidden, "billing_error", msg
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment