Commit 195e227c authored by song's avatar song
Browse files

merge: 合并 upstream/main 并保留本地图片计费功能

parents 6fa704d6 752882a0
...@@ -76,6 +76,7 @@ func NewAccountHandler( ...@@ -76,6 +76,7 @@ func NewAccountHandler(
// CreateAccountRequest represents create account request // CreateAccountRequest represents create account request
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"` Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"` Credentials map[string]any `json:"credentials" binding:"required"`
...@@ -91,6 +92,7 @@ type CreateAccountRequest struct { ...@@ -91,6 +92,7 @@ type CreateAccountRequest struct {
// 使用指针类型来区分"未提供"和"设置为0" // 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
...@@ -193,6 +195,7 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -193,6 +195,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name, Name: req.Name,
Notes: req.Notes,
Platform: req.Platform, Platform: req.Platform,
Type: req.Type, Type: req.Type,
Credentials: req.Credentials, Credentials: req.Credentials,
...@@ -249,6 +252,7 @@ func (h *AccountHandler) Update(c *gin.Context) { ...@@ -249,6 +252,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name, Name: req.Name,
Notes: req.Notes,
Type: req.Type, Type: req.Type,
Credentials: req.Credentials, Credentials: req.Credentials,
Extra: req.Extra, Extra: req.Extra,
...@@ -357,7 +361,8 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { ...@@ -357,7 +361,8 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies, SyncProxies: syncProxies,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) // Provide detailed error message for CRS sync failures
response.InternalError(c, "CRS sync failed: "+err.Error())
return return
} }
......
package admin package admin
import ( import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -34,31 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -34,31 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort, SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername, SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword, SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom, SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName, SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS, SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey, TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName, SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo, SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle, SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI, FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini, FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity, FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
}) })
} }
...@@ -100,6 +106,10 @@ type UpdateSettingsRequest struct { ...@@ -100,6 +106,10 @@ type UpdateSettingsRequest struct {
FallbackModelOpenAI string `json:"fallback_model_openai"` FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"` FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"` FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
...@@ -111,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -111,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数 // 验证参数
if req.DefaultConcurrency < 1 { if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1 req.DefaultConcurrency = 1
...@@ -129,21 +145,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -129,21 +145,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "Turnstile Site Key is required when enabled") response.BadRequest(c, "Turnstile Site Key is required when enabled")
return return
} }
// 如果未提供 secret key,使用已保存的值(留空保留当前值)
if req.TurnstileSecretKey == "" { if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled") if previousSettings.TurnstileSecretKey == "" {
return response.BadRequest(c, "Turnstile Secret Key is required when enabled")
} return
}
// 获取当前设置,检查参数是否有变化 req.TurnstileSecretKey = previousSettings.TurnstileSecretKey
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
} }
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录) // 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey siteKeyChanged := previousSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey secretKeyChanged := previousSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged { if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil { if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -178,6 +191,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -178,6 +191,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelOpenAI: req.FallbackModelOpenAI, FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini, FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity, FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
...@@ -185,6 +200,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -185,6 +200,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回 // 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
...@@ -193,34 +210,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -193,34 +210,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort, SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername, SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword, SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom, SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName, SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS, SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled, TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey, TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey, TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName, SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo, SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle, SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL, APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo, ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL, DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini, FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
}) })
} }
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host" binding:"required"`
......
...@@ -109,6 +109,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -109,6 +109,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
return &Account{ return &Account{
ID: a.ID, ID: a.ID,
Name: a.Name, Name: a.Name,
Notes: a.Notes,
Platform: a.Platform, Platform: a.Platform,
Type: a.Type, Type: a.Type,
Credentials: a.Credentials, Credentials: a.Credentials,
......
...@@ -5,17 +5,17 @@ type SystemSettings struct { ...@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"` SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"` SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"` SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"` SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
...@@ -33,6 +33,10 @@ type SystemSettings struct { ...@@ -33,6 +33,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"` FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"` FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"` FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
} }
type PublicSettings struct { type PublicSettings struct {
......
...@@ -62,6 +62,7 @@ type Group struct { ...@@ -62,6 +62,7 @@ type Group struct {
type Account struct { type Account struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"` Platform string `json:"platform"`
Type string `json:"type"` Type string `json:"type"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
......
...@@ -11,8 +11,10 @@ import ( ...@@ -11,8 +11,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -38,14 +40,19 @@ func NewGatewayHandler( ...@@ -38,14 +40,19 @@ func NewGatewayHandler(
userService *service.UserService, userService *service.UserService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
geminiCompatService: geminiCompatService, geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
} }
} }
...@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
defer userReleaseFunc() defer userReleaseFunc()
} }
...@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅 // 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
...@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
...@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额) // 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额 // 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error()) status, code, message := billingErrorDetails(err)
h.errorResponse(c, status, code, message)
return return
} }
...@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) { ...@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
}, },
}) })
} }
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
return http.StatusForbidden, "billing_error", msg
}
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -26,8 +27,8 @@ import ( ...@@ -26,8 +27,8 @@ import (
const ( const (
// maxConcurrencyWait 等待并发槽位的最大时间 // maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔 // defaultPingInterval 流式响应等待时发送 ping 的默认间隔
pingInterval = 15 * time.Second defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间 // initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避) // backoffMultiplier 退避时间乘数(指数退避)
...@@ -44,6 +45,8 @@ const ( ...@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n" SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec) // SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = "" SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
) )
// ConcurrencyError represents a concurrency limit error with context // ConcurrencyError represents a concurrency limit error with context
...@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string { ...@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct { type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat pingFormat SSEPingFormat
pingInterval time.Duration
} }
// NewConcurrencyHelper creates a new ConcurrencyHelper // NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper { func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{ return &ConcurrencyHelper{
concurrencyService: concurrencyService, concurrencyService: concurrencyService,
pingFormat: pingFormat, pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
} }
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
} }
// IncrementWaitCount increments the wait count for a user // IncrementWaitCount increments the wait count for a user
...@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed // Only create ping ticker if ping is needed
var pingCh <-chan time.Time var pingCh <-chan time.Time
if needPing { if needPing {
pingTicker := time.NewTicker(pingInterval) pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop() defer pingTicker.Stop()
pingCh = pingTicker.C pingCh = pingTicker.C
} }
......
...@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames. // For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check // 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency) maxWait := service.CalculateMaxWait(authSubject.Concurrency)
...@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
defer userReleaseFunc() defer userReleaseFunc()
} }
// 2) billing eligibility check (after wait) // 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error()) status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return return
} }
...@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler( ...@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
} }
} }
...@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
defer userReleaseFunc() defer userReleaseFunc()
} }
...@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
...@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request // Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
......
...@@ -4,13 +4,34 @@ import ( ...@@ -4,13 +4,34 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"os"
"strings" "strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string
}
func DefaultTransformOptions() TransformOptions {
return TransformOptions{
EnableIdentityPatch: true,
}
}
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
}
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
...@@ -22,16 +43,24 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -22,16 +43,24 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 1. 构建 contents // 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil { if err != nil {
return nil, fmt.Errorf("build contents: %w", err) return nil, fmt.Errorf("build contents: %w", err)
} }
// 2. 构建 systemInstruction // 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
// 3. 构建 generationConfig // 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq) reqForConfig := claudeReq
if strippedThinking {
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
// disable upstream thinking mode to avoid signature/structure validation errors.
reqCopy := *claudeReq
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools // 4. 构建 tools
tools := buildTools(claudeReq.Tools) tools := buildTools(claudeReq.Tools)
...@@ -75,12 +104,8 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -75,12 +104,8 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
return json.Marshal(v1Req) return json.Marshal(v1Req)
} }
// buildSystemInstruction 构建 systemInstruction func defaultIdentityPatch(modelName string) string {
func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiContent { return fmt.Sprintf(
var parts []GeminiPart
// 注入身份防护指令
identityPatch := fmt.Sprintf(
"--- [IDENTITY_PATCH] ---\n"+ "--- [IDENTITY_PATCH] ---\n"+
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+ "Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
"You are currently providing services as the native %s model via a standard API proxy.\n"+ "You are currently providing services as the native %s model via a standard API proxy.\n"+
...@@ -88,7 +113,20 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon ...@@ -88,7 +113,20 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
"--- [SYSTEM_PROMPT_BEGIN] ---\n", "--- [SYSTEM_PROMPT_BEGIN] ---\n",
modelName, modelName,
) )
parts = append(parts, GeminiPart{Text: identityPatch}) }
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart
// 可选注入身份防护指令(身份补丁)
if opts.EnableIdentityPatch {
identityPatch := strings.TrimSpace(opts.IdentityPatch)
if identityPatch == "" {
identityPatch = defaultIdentityPatch(modelName)
}
parts = append(parts, GeminiPart{Text: identityPatch})
}
// 解析 system prompt // 解析 system prompt
if len(system) > 0 { if len(system) > 0 {
...@@ -111,7 +149,13 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon ...@@ -111,7 +149,13 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
} }
} }
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"}) // identity patch 模式下,用分隔符包裹 system prompt,便于上游识别/调试;关闭时尽量保持原始 system prompt。
if opts.EnableIdentityPatch && len(parts) > 0 {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 {
return nil
}
return &GeminiContent{ return &GeminiContent{
Role: "user", Role: "user",
...@@ -120,8 +164,9 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon ...@@ -120,8 +164,9 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
} }
// buildContents 构建 contents // buildContents 构建 contents
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, error) { func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, bool, error) {
var contents []GeminiContent var contents []GeminiContent
strippedThinking := false
for i, msg := range messages { for i, msg := range messages {
role := msg.Role role := msg.Role
...@@ -129,9 +174,12 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -129,9 +174,12 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
role = "model" role = "model"
} }
parts, err := buildParts(msg.Content, toolIDToName, allowDummyThought) parts, strippedThisMsg, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
if err != nil { if err != nil {
return nil, fmt.Errorf("build parts for message %d: %w", i, err) return nil, false, fmt.Errorf("build parts for message %d: %w", i, err)
}
if strippedThisMsg {
strippedThinking = true
} }
// 只有 Gemini 模型支持 dummy thinking block workaround // 只有 Gemini 模型支持 dummy thinking block workaround
...@@ -165,7 +213,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -165,7 +213,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
}) })
} }
return contents, nil return contents, strippedThinking, nil
} }
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 // dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
...@@ -174,8 +222,9 @@ const dummyThoughtSignature = "skip_thought_signature_validator" ...@@ -174,8 +222,9 @@ const dummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, bool, error) {
var parts []GeminiPart var parts []GeminiPart
strippedThinking := false
// 尝试解析为字符串 // 尝试解析为字符串
var textContent string var textContent string
...@@ -183,13 +232,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -183,13 +232,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if textContent != "(no content)" && strings.TrimSpace(textContent) != "" { if textContent != "(no content)" && strings.TrimSpace(textContent) != "" {
parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)}) parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)})
} }
return parts, nil return parts, false, nil
} }
// 解析为内容块数组 // 解析为内容块数组
var blocks []ContentBlock var blocks []ContentBlock
if err := json.Unmarshal(content, &blocks); err != nil { if err := json.Unmarshal(content, &blocks); err != nil {
return nil, fmt.Errorf("parse content blocks: %w", err) return nil, false, fmt.Errorf("parse content blocks: %w", err)
} }
for _, block := range blocks { for _, block := range blocks {
...@@ -208,8 +257,11 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -208,8 +257,11 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if block.Signature != "" { if block.Signature != "" {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if !allowDummyThought { } else if !allowDummyThought {
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
log.Printf("Warning: skipping thinking block without signature for Claude model") if strings.TrimSpace(block.Thinking) != "" {
parts = append(parts, GeminiPart{Text: block.Thinking})
}
strippedThinking = true
continue continue
} else { } else {
// Gemini 模型使用 dummy signature // Gemini 模型使用 dummy signature
...@@ -276,7 +328,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -276,7 +328,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
} }
} }
return parts, nil return parts, strippedThinking, nil
} }
// parseToolResultContent 解析 tool_result 的 content // parseToolResultContent 解析 tool_result 的 content
...@@ -446,7 +498,7 @@ func cleanJSONSchema(schema map[string]any) map[string]any { ...@@ -446,7 +498,7 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil { if schema == nil {
return nil return nil
} }
cleaned := cleanSchemaValue(schema) cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any) result, ok := cleaned.(map[string]any)
if !ok { if !ok {
return nil return nil
...@@ -484,6 +536,56 @@ func cleanJSONSchema(schema map[string]any) map[string]any { ...@@ -484,6 +536,56 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
return result return result
} }
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出(debug/test)
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段 // excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况 // 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items // 支持: type, description, enum, properties, required, additionalProperties, items
...@@ -546,13 +648,14 @@ var excludedSchemaKeys = map[string]bool{ ...@@ -546,13 +648,14 @@ var excludedSchemaKeys = map[string]bool{
} }
// cleanSchemaValue 递归清理 schema 值 // cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any) any { func cleanSchemaValue(value any, path string) any {
switch v := value.(type) { switch v := value.(type) {
case map[string]any: case map[string]any:
result := make(map[string]any) result := make(map[string]any)
for k, val := range v { for k, val := range v {
// 跳过不支持的字段 // 跳过不支持的字段
if excludedSchemaKeys[k] { if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue continue
} }
...@@ -586,15 +689,15 @@ func cleanSchemaValue(value any) any { ...@@ -586,15 +689,15 @@ func cleanSchemaValue(value any) any {
} }
// 递归清理所有值 // 递归清理所有值
result[k] = cleanSchemaValue(val) result[k] = cleanSchemaValue(val, path+"."+k)
} }
return result return result
case []any: case []any:
// 递归处理数组中的每个元素 // 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v)) cleaned := make([]any, 0, len(v))
for _, item := range v { for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item)) cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
} }
return cleaned return cleaned
......
...@@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { ...@@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
description string description string
}{ }{
{ {
name: "Claude model - drop thinking without signature", name: "Claude model - downgrade thinking to text without signature",
content: `[ content: `[
{"type": "text", "text": "Hello"}, {"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""}, {"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"} {"type": "text", "text": "World"}
]`, ]`,
allowDummyThought: false, allowDummyThought: false,
expectedParts: 2, // thinking 内容被丢弃 expectedParts: 3, // thinking 内容降级为普通 text part
description: "Claude模型应丢弃无signaturethinking block内容", description: "Claude模型缺少signature时应将thinking降级为text,并在上层禁用thinking mode",
}, },
{ {
name: "Claude model - preserve thinking block with signature", name: "Claude model - preserve thinking block with signature",
...@@ -52,7 +52,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { ...@@ -52,7 +52,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) parts, _, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil { if err != nil {
t.Fatalf("buildParts() error = %v", err) t.Fatalf("buildParts() error = %v", err)
...@@ -71,6 +71,17 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { ...@@ -71,6 +71,17 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q", t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature) parts[1].Thought, parts[1].ThoughtSignature)
} }
case "Claude model - downgrade thinking to text without signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if parts[1].Thought {
t.Fatalf("expected downgraded text part, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
if parts[1].Text != "Let me think..." {
t.Fatalf("expected downgraded text %q, got %q", "Let me think...", parts[1].Text)
}
case "Gemini model - use dummy signature": case "Gemini model - use dummy signature":
if len(parts) != 3 { if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts)) t.Fatalf("expected 3 parts, got %d", len(parts))
...@@ -91,7 +102,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -91,7 +102,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) { t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, true) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil { if err != nil {
t.Fatalf("buildParts() error = %v", err) t.Fatalf("buildParts() error = %v", err)
} }
...@@ -105,7 +116,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -105,7 +116,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) { t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, false) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, false)
if err != nil { if err != nil {
t.Fatalf("buildParts() error = %v", err) t.Fatalf("buildParts() error = %v", err)
} }
......
...@@ -25,13 +25,14 @@ import ( ...@@ -25,13 +25,14 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
// Transport 连接池默认配置 // Transport 连接池默认配置
const ( const (
defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
) )
// Options 定义共享 HTTP 客户端的构建参数 // Options 定义共享 HTTP 客户端的构建参数
...@@ -40,6 +41,9 @@ type Options struct { ...@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间 Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证 InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值) // 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100) MaxIdleConns int // 最大空闲连接总数(默认 100)
...@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) { ...@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err return nil, err
} }
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{ return &http.Client{
Transport: transport, Transport: rt,
Timeout: opts.Timeout, Timeout: opts.Timeout,
}, nil }, nil
} }
...@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) { ...@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
} }
func buildClientKey(opts Options) string { func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d", return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL), strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(), opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(), opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify, opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns, opts.MaxIdleConns,
opts.MaxIdleConnsPerHost, opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost, opts.MaxConnsPerHost,
) )
} }
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}
...@@ -67,6 +67,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -67,6 +67,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
builder := r.client.Account.Create(). builder := r.client.Account.Create().
SetName(account.Name). SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform). SetPlatform(account.Platform).
SetType(account.Type). SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)). SetCredentials(normalizeJSONMap(account.Credentials)).
...@@ -270,6 +271,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -270,6 +271,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
builder := r.client.Account.UpdateOneID(account.ID). builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name). SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform). SetPlatform(account.Platform).
SetType(account.Type). SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)). SetCredentials(normalizeJSONMap(account.Credentials)).
...@@ -320,6 +322,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -320,6 +322,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else { } else {
builder.ClearSessionWindowStatus() builder.ClearSessionWindowStatus()
} }
if account.Notes == nil {
builder.ClearNotes()
}
updated, err := builder.Save(ctx) updated, err := builder.Save(ctx)
if err != nil { if err != nil {
...@@ -768,9 +773,14 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -768,9 +773,14 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
idx++ idx++
} }
if updates.ProxyID != nil { if updates.ProxyID != nil {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx)) // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
args = append(args, *updates.ProxyID) if *updates.ProxyID == 0 {
idx++ setClauses = append(setClauses, "proxy_id = NULL")
} else {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
}
} }
if updates.Concurrency != nil { if updates.Concurrency != nil {
setClauses = append(setClauses, "concurrency = $"+itoa(idx)) setClauses = append(setClauses, "concurrency = $"+itoa(idx))
...@@ -1065,6 +1075,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1065,6 +1075,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return &service.Account{ return &service.Account{
ID: m.ID, ID: m.ID,
Name: m.Name, Name: m.Name,
Notes: m.Notes,
Platform: m.Platform, Platform: m.Platform,
Type: m.Type, Type: m.Type,
Credentials: copyJSONMap(m.Credentials), Credentials: copyJSONMap(m.Credentials),
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
...@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey ...@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256", "code_challenge_method": "S256",
} }
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct { var result struct {
...@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState fullCode = authCode + "#" + responseState
} }
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil return fullCode, nil
} }
...@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
} }
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
...@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
...@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client { ...@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client return client
} }
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}
...@@ -15,7 +15,8 @@ import ( ...@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct { type claudeUsageService struct {
usageURL string usageURL string
allowPrivateHosts bool
} }
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
...@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { ...@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} client = &http.Client{Timeout: 30 * time.Second}
......
...@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { ...@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`) }`)
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage") require.NoError(s.T(), err, "FetchUsage")
...@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { ...@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope") _, _ = io.WriteString(w, "nope")
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "") _, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err) require.Error(s.T(), err)
...@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { ...@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json") _, _ = io.WriteString(w, "not-json")
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "") _, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err) require.Error(s.T(), err)
...@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { ...@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL} s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately cancel() // Cancel immediately
......
...@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { ...@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 确保数据库 schema 已准备就绪。 // 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。 // SQL 迁移文件是 schema 的权威来源(source of truth)。
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。 // 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel() defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil { if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露 _ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
......
...@@ -14,18 +14,23 @@ import ( ...@@ -14,18 +14,23 @@ import (
) )
type githubReleaseClient struct { type githubReleaseClient struct {
httpClient *http.Client httpClient *http.Client
allowPrivateHosts bool
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
} }
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: sharedClient, httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
} }
} }
...@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
} }
downloadClient, err := httpclient.GetClient(httpclient.Options{ downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute, Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
}) })
if err != nil { if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute} downloadClient = &http.Client{Timeout: 10 * time.Minute}
......
...@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq) return http.DefaultTransport.RoundTrip(newReq)
} }
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() { func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir() s.tempDir = s.T().TempDir()
} }
...@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng ...@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100)) _, _ = w.Write(bytes.Repeat([]byte("a"), 100))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file1.bin") dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
...@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { ...@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
} }
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file2.bin") dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
...@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { ...@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
} }
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file3.bin") dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
...@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { ...@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "notfound.bin") dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100) err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
...@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { ...@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum")) _, _ = w.Write([]byte("sum"))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile") require.NoError(s.T(), err, "FetchChecksumFile")
...@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { ...@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) _, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200") require.Error(s.T(), err, "expected error for non-200")
...@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { ...@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
...@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { ...@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "invalid.bin") dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100) err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
...@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { ...@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content")) _, _ = w.Write([]byte("content"))
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
// Use a path that cannot be created (directory doesn't exist) // Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin") dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
...@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { ...@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url") _, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL") require.Error(s.T(), err, "expected error for invalid URL")
...@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { ...@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { ...@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { ...@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true,
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
...@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { ...@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done() <-r.Context().Done()
})) }))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient) s.client = newTestGitHubReleaseClient()
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
// 默认配置常量 // 默认配置常量
...@@ -30,9 +31,9 @@ const ( ...@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接 // 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240 defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟 // defaultIdleConnTimeout: 默认空闲连接超时时间(90秒
// 超时后连接会被关闭,释放系统资源 // 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 300 * time.Second defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时 // LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second defaultResponseHeaderTimeout = 300 * time.Second
...@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { ...@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 // - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 // - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用 // 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil { if err != nil {
...@@ -145,6 +150,40 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i ...@@ -145,6 +150,40 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil return resp, nil
} }
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
if !s.cfg.Security.URLAllowlist.Enabled {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求 // acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰 // 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
...@@ -232,6 +271,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a ...@@ -232,6 +271,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err) return nil, fmt.Errorf("build transport: %w", err)
} }
client := &http.Client{Transport: transport} client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{ entry := &upstreamClientEntry{
client: client, client: client,
proxyKey: proxyKey, proxyKey: proxyKey,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment