Commit 195e227c authored by song's avatar song
Browse files

merge: 合并 upstream/main 并保留本地图片计费功能

parents 6fa704d6 752882a0
......@@ -76,6 +76,7 @@ func NewAccountHandler(
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
......@@ -91,6 +92,7 @@ type CreateAccountRequest struct {
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
......@@ -193,6 +195,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
......@@ -249,6 +252,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Notes: req.Notes,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
......@@ -357,7 +361,8 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies,
})
if err != nil {
response.ErrorFrom(c, err)
// Provide detailed error message for CRS sync failures
response.InternalError(c, "CRS sync failed: "+err.Error())
return
}
......
package admin
import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
......@@ -39,13 +43,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
......@@ -59,6 +63,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
})
}
......@@ -100,6 +106,10 @@ type UpdateSettingsRequest struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
// UpdateSettings 更新系统设置
......@@ -111,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
......@@ -129,21 +145,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
// 如果未提供 secret key,使用已保存的值(留空保留当前值)
if req.TurnstileSecretKey == "" {
if previousSettings.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
req.TurnstileSecretKey = previousSettings.TurnstileSecretKey
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
siteKeyChanged := previousSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := previousSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
......@@ -178,6 +191,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
......@@ -185,6 +200,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
......@@ -198,13 +215,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
......@@ -218,9 +235,111 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
})
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
......
......@@ -109,6 +109,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
return &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
......
......@@ -8,14 +8,14 @@ type SystemSettings struct {
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
......@@ -33,6 +33,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
}
type PublicSettings struct {
......
......@@ -62,6 +62,7 @@ type Group struct {
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
......
......@@ -11,8 +11,10 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -38,14 +40,19 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
}
}
......@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
......@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
......@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
......@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
......@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
status, code, message := billingErrorDetails(err)
h.errorResponse(c, status, code, message)
return
}
......@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
return http.StatusForbidden, "billing_error", msg
}
......@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -26,8 +27,8 @@ import (
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// defaultPingInterval 流式响应等待时发送 ping 的默认间隔
defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
......@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
......@@ -63,14 +66,36 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
}
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
}
// IncrementWaitCount increments the wait count for a user
......@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
......
......@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
......@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
}
......@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult
......
......@@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
}
}
......@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
......@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
......@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
......
......@@ -4,13 +4,34 @@ import (
"encoding/json"
"fmt"
"log"
"os"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string
}
func DefaultTransformOptions() TransformOptions {
return TransformOptions{
EnableIdentityPatch: true,
}
}
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
}
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
......@@ -22,16 +43,24 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build contents: %w", err)
}
// 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForConfig := claudeReq
if strippedThinking {
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
// disable upstream thinking mode to avoid signature/structure validation errors.
reqCopy := *claudeReq
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
......@@ -75,12 +104,8 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
return json.Marshal(v1Req)
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiContent {
var parts []GeminiPart
// 注入身份防护指令
identityPatch := fmt.Sprintf(
func defaultIdentityPatch(modelName string) string {
return fmt.Sprintf(
"--- [IDENTITY_PATCH] ---\n"+
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
"You are currently providing services as the native %s model via a standard API proxy.\n"+
......@@ -88,7 +113,20 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
"--- [SYSTEM_PROMPT_BEGIN] ---\n",
modelName,
)
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart
// 可选注入身份防护指令(身份补丁)
if opts.EnableIdentityPatch {
identityPatch := strings.TrimSpace(opts.IdentityPatch)
if identityPatch == "" {
identityPatch = defaultIdentityPatch(modelName)
}
parts = append(parts, GeminiPart{Text: identityPatch})
}
// 解析 system prompt
if len(system) > 0 {
......@@ -111,7 +149,13 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
}
}
// identity patch 模式下,用分隔符包裹 system prompt,便于上游识别/调试;关闭时尽量保持原始 system prompt。
if opts.EnableIdentityPatch && len(parts) > 0 {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 {
return nil
}
return &GeminiContent{
Role: "user",
......@@ -120,8 +164,9 @@ func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiCon
}
// buildContents 构建 contents
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, error) {
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, bool, error) {
var contents []GeminiContent
strippedThinking := false
for i, msg := range messages {
role := msg.Role
......@@ -129,9 +174,12 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
role = "model"
}
parts, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
parts, strippedThisMsg, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build parts for message %d: %w", i, err)
return nil, false, fmt.Errorf("build parts for message %d: %w", i, err)
}
if strippedThisMsg {
strippedThinking = true
}
// 只有 Gemini 模型支持 dummy thinking block workaround
......@@ -165,7 +213,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
})
}
return contents, nil
return contents, strippedThinking, nil
}
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
......@@ -174,8 +222,9 @@ const dummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, bool, error) {
var parts []GeminiPart
strippedThinking := false
// 尝试解析为字符串
var textContent string
......@@ -183,13 +232,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if textContent != "(no content)" && strings.TrimSpace(textContent) != "" {
parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)})
}
return parts, nil
return parts, false, nil
}
// 解析为内容块数组
var blocks []ContentBlock
if err := json.Unmarshal(content, &blocks); err != nil {
return nil, fmt.Errorf("parse content blocks: %w", err)
return nil, false, fmt.Errorf("parse content blocks: %w", err)
}
for _, block := range blocks {
......@@ -208,8 +257,11 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
if strings.TrimSpace(block.Thinking) != "" {
parts = append(parts, GeminiPart{Text: block.Thinking})
}
strippedThinking = true
continue
} else {
// Gemini 模型使用 dummy signature
......@@ -276,7 +328,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
}
return parts, nil
return parts, strippedThinking, nil
}
// parseToolResultContent 解析 tool_result 的 content
......@@ -446,7 +498,7 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema)
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
......@@ -484,6 +536,56 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出(debug/test)
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
......@@ -546,13 +648,14 @@ var excludedSchemaKeys = map[string]bool{
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any) any {
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
......@@ -586,15 +689,15 @@ func cleanSchemaValue(value any) any {
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for _, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
......
......@@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
description string
}{
{
name: "Claude model - drop thinking without signature",
name: "Claude model - downgrade thinking to text without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // thinking 内容被丢弃
description: "Claude模型应丢弃无signaturethinking block内容",
expectedParts: 3, // thinking 内容降级为普通 text part
description: "Claude模型缺少signature时应将thinking降级为text,并在上层禁用thinking mode",
},
{
name: "Claude model - preserve thinking block with signature",
......@@ -52,7 +52,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
parts, _, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
......@@ -71,6 +71,17 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
case "Claude model - downgrade thinking to text without signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if parts[1].Thought {
t.Fatalf("expected downgraded text part, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
if parts[1].Text != "Let me think..." {
t.Fatalf("expected downgraded text %q, got %q", "Let me think...", parts[1].Text)
}
case "Gemini model - use dummy signature":
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
......@@ -91,7 +102,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, true)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
......@@ -105,7 +116,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(content), toolIDToName, false)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, false)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
......
......@@ -25,13 +25,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
)
// Options 定义共享 HTTP 客户端的构建参数
......@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100)
......@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err
}
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{
Transport: transport,
Transport: rt,
Timeout: opts.Timeout,
}, nil
}
......@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}
......@@ -67,6 +67,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
builder := r.client.Account.Create().
SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
......@@ -270,6 +271,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name).
SetNillableNotes(account.Notes).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
......@@ -320,6 +322,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else {
builder.ClearSessionWindowStatus()
}
if account.Notes == nil {
builder.ClearNotes()
}
updated, err := builder.Save(ctx)
if err != nil {
......@@ -768,10 +773,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
idx++
}
if updates.ProxyID != nil {
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *updates.ProxyID == 0 {
setClauses = append(setClauses, "proxy_id = NULL")
} else {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
}
}
if updates.Concurrency != nil {
setClauses = append(setClauses, "concurrency = $"+itoa(idx))
args = append(args, *updates.Concurrency)
......@@ -1065,6 +1075,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return &service.Account{
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
......
......@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
......@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
......@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
......@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
......@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
......@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
......@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
......@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}
......@@ -16,6 +16,7 @@ const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
......@@ -26,6 +27,8 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
......
......@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
......@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
......@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
......@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
......
......@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
......
......@@ -15,17 +15,22 @@ import (
type githubReleaseClient struct {
httpClient *http.Client
allowPrivateHosts bool
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
}
}
......@@ -65,6 +70,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
......
......@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq)
}
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
......@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
......@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
......@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
......@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
......@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
......@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
......@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
......@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
......@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
......@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
......@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
......@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
......
......@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
......@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultIdleConnTimeout: 默认空闲连接超时时间(90秒
// 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
......@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
......@@ -145,6 +150,40 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
if !s.cfg.Security.URLAllowlist.Enabled {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
......@@ -232,6 +271,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment