Unverified commit 6bccb8a8, authored by Wesley Liddick, committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"log/slog" "log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) { ...@@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
return return
} }
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) // Turnstile 验证 — 始终执行,防止绕过
if req.VerifyCode == "" { // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
}
} }
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
...@@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { ...@@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return return
} }
// Build frontend base URL from request frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
scheme := "https" if frontendBaseURL == "" {
if c.Request.TLS == nil { slog.Error("server.frontend_url not configured; cannot build password reset link")
// Check X-Forwarded-Proto header (common in reverse proxy setups) response.InternalError(c, "Password reset is not configured")
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { return
scheme = proto
} else {
scheme = "http"
}
} }
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async) // Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration) // Note: This returns success even if email doesn't exist (to prevent enumeration)
......
package dto
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestAPIKeyFromService_MapsLastUsedAt verifies that a non-nil LastUsedAt
// timestamp on the service-layer APIKey is carried through to the DTO.
func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) {
	// Truncate to whole seconds so the WithinDuration tolerance below is stable.
	usedAt := time.Now().UTC().Truncate(time.Second)

	input := &service.APIKey{
		ID:         1,
		UserID:     2,
		Key:        "sk-map-last-used",
		Name:       "Mapper",
		Status:     service.StatusActive,
		LastUsedAt: &usedAt,
	}

	got := APIKeyFromService(input)

	require.NotNil(t, got)
	require.NotNil(t, got.LastUsedAt)
	// The mapped timestamp must match the source value.
	require.WithinDuration(t, usedAt, *got.LastUsedAt, time.Second)
}
// TestAPIKeyFromService_MapsNilLastUsedAt verifies that a never-used key
// (nil LastUsedAt) stays nil in the DTO rather than being zero-filled.
func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) {
	input := &service.APIKey{
		ID:     1,
		UserID: 2,
		Key:    "sk-map-last-used-nil",
		Name:   "MapperNil",
		Status: service.StatusActive,
		// LastUsedAt intentionally left nil.
	}

	got := APIKeyFromService(input)

	require.NotNil(t, got)
	require.Nil(t, got.LastUsedAt)
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package dto package dto
import ( import (
"strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey { ...@@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status: k.Status, Status: k.Status,
IPWhitelist: k.IPWhitelist, IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist, IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt,
Quota: k.Quota, Quota: k.Quota,
QuotaUsed: k.QuotaUsed, QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt, ExpiresAt: k.ExpiresAt,
...@@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ...@@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
func groupFromServiceBase(g *service.Group) Group { func groupFromServiceBase(g *service.Group) Group {
return Group{ return Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
Platform: g.Platform, Platform: g.Platform,
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD, DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly, SoraImagePrice360: g.SoraImagePrice360,
FallbackGroupID: g.FallbackGroupID, SoraImagePrice540: g.SoraImagePrice540,
// 无效请求兜底分组 SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
...@@ -300,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi ...@@ -300,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode: p.CountryCode, CountryCode: p.CountryCode,
Region: p.Region, Region: p.Region,
City: p.City, City: p.City,
QualityStatus: p.QualityStatus,
QualityScore: p.QualityScore,
QualityGrade: p.QualityGrade,
QualitySummary: p.QualitySummary,
QualityChecked: p.QualityChecked,
} }
} }
...@@ -404,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -404,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount, ImageCount: l.ImageCount,
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent, UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden, CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
...@@ -532,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult ...@@ -532,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
for i := range r.Subscriptions { for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
} }
statuses := make(map[string]string, len(r.Statuses))
for userID, status := range r.Statuses {
statuses[strconv.FormatInt(userID, 10)] = status
}
return &BulkAssignResult{ return &BulkAssignResult{
SuccessCount: r.SuccessCount, SuccessCount: r.SuccessCount,
CreatedCount: r.CreatedCount,
ReusedCount: r.ReusedCount,
FailedCount: r.FailedCount, FailedCount: r.FailedCount,
Subscriptions: subs, Subscriptions: subs,
Errors: r.Errors, Errors: r.Errors,
Statuses: statuses,
} }
} }
......
...@@ -38,6 +38,7 @@ type APIKey struct { ...@@ -38,6 +38,7 @@ type APIKey struct {
Status string `json:"status"` Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"` IPBlacklist []string `json:"ip_blacklist"`
LastUsedAt *time.Time `json:"last_used_at"`
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
...@@ -67,6 +68,12 @@ type Group struct { ...@@ -67,6 +68,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
...@@ -196,6 +203,11 @@ type ProxyWithAccountCount struct { ...@@ -196,6 +203,11 @@ type ProxyWithAccountCount struct {
CountryCode string `json:"country_code,omitempty"` CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"` Region string `json:"region,omitempty"`
City string `json:"city,omitempty"` City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityChecked *int64 `json:"quality_checked,omitempty"`
} }
type ProxyAccountSummary struct { type ProxyAccountSummary struct {
...@@ -274,6 +286,7 @@ type UsageLog struct { ...@@ -274,6 +286,7 @@ type UsageLog struct {
// 图片生成字段 // 图片生成字段
ImageCount int `json:"image_count"` ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"` ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
...@@ -382,9 +395,12 @@ type AdminUserSubscription struct { ...@@ -382,9 +395,12 @@ type AdminUserSubscription struct {
type BulkAssignResult struct { type BulkAssignResult struct {
SuccessCount int `json:"success_count"` SuccessCount int `json:"success_count"`
CreatedCount int `json:"created_count"`
ReusedCount int `json:"reused_count"`
FailedCount int `json:"failed_count"` FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"` Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"` Errors []string `json:"errors"`
Statuses map[string]string `json:"statuses,omitempty"`
} }
// PromoCode 注册优惠码 // PromoCode 注册优惠码
......
...@@ -19,11 +19,13 @@ import ( ...@@ -19,11 +19,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap"
) )
// GatewayHandler handles API gateway requests // GatewayHandler handles API gateway requests
...@@ -35,10 +37,12 @@ type GatewayHandler struct { ...@@ -35,10 +37,12 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
cfg *config.Config
} }
// NewGatewayHandler creates a new GatewayHandler // NewGatewayHandler creates a new GatewayHandler
...@@ -51,6 +55,7 @@ func NewGatewayHandler( ...@@ -51,6 +55,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
...@@ -74,10 +79,12 @@ func NewGatewayHandler( ...@@ -74,10 +79,12 @@ func NewGatewayHandler(
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
} }
} }
...@@ -96,6 +103,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -96,6 +103,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
} }
reqLog := requestLogger(
c,
"handler.gateway.messages",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
...@@ -122,6 +136,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -122,6 +136,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
reqModel := parsedReq.Model reqModel := parsedReq.Model
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
...@@ -161,9 +176,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -161,9 +176,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false waitCounted := false
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
// On error, allow request to proceed // On error, allow request to proceed
} else if !canWait { } else if !canWait {
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return return
} }
...@@ -180,7 +196,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -180,7 +196,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
...@@ -197,7 +213,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -197,7 +213,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅 // 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
...@@ -227,6 +243,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -227,6 +243,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64 var sessionBoundAccountID int64
if sessionKey != "" { if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
} }
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
...@@ -250,7 +275,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -250,7 +275,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
// Antigravity 单账号退避重试:分组内没有其他可用账号时, // Antigravity 单账号退避重试:分组内没有其他可用账号时,
...@@ -258,7 +284,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -258,7 +284,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{}) failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
...@@ -274,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -274,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截(预热请求、SUGGESTION MODE等) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() { if account.IsInterceptWarmupEnabled() {
...@@ -302,21 +331,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -302,21 +331,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err) reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait { } else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID) reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return return
} }
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
// Ensure the wait counter is decremented if we exit before acquiring the slot. releaseWait := func() {
defer func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
} }
}() }
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -327,17 +359,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -327,17 +359,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted, &streamStarted,
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
// Slot acquired: no longer waiting in queue. // Slot acquired: no longer waiting in queue.
if accountWaitCounted { releaseWait()
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
...@@ -387,7 +417,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -387,7 +417,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
switchCount++ switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) { if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return return
...@@ -395,8 +430,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -395,8 +430,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
continue continue
} }
// 错误响应已在Forward中处理,这里只记录日志 wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
log.Printf("Forward request failed: %v", err) reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
...@@ -404,24 +443,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -404,24 +443,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
return return
} }
} }
...@@ -455,7 +499,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -455,7 +499,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
// Antigravity 单账号退避重试:分组内没有其他可用账号时, // Antigravity 单账号退避重试:分组内没有其他可用账号时,
...@@ -463,7 +508,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -463,7 +508,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{}) failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
...@@ -479,7 +527,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -479,7 +527,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截(预热请求、SUGGESTION MODE等) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() { if account.IsInterceptWarmupEnabled() {
...@@ -507,20 +555,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -507,20 +555,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err) reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait { } else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID) reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return return
} }
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
defer func() { releaseWait := func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
} }
}() }
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -531,16 +583,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -531,16 +583,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted, &streamStarted,
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted { // Slot acquired: no longer waiting in queue.
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) releaseWait()
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
...@@ -563,18 +614,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -563,18 +614,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil { if err != nil {
var promptTooLongErr *service.PromptTooLongError var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) { if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) reqLog.Warn("gateway.prompt_too_long_from_antigravity",
zap.Any("current_group_id", currentAPIKey.GroupID),
zap.Any("fallback_group_id", fallbackGroupID),
zap.Bool("fallback_used", fallbackUsed),
)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil { if err != nil {
log.Printf("Resolve fallback group failed: %v", err) reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return return
} }
if fallbackGroup.Platform != service.PlatformAnthropic || if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) reqLog.Warn("gateway.fallback_group_invalid",
zap.Int64("fallback_group_id", fallbackGroup.ID),
zap.String("fallback_platform", fallbackGroup.Platform),
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return return
} }
...@@ -625,7 +684,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -625,7 +684,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
switchCount++ switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) { if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return return
...@@ -633,8 +697,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -633,8 +697,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
continue continue
} }
// 错误响应已在Forward中处理,这里只记录日志 wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
log.Printf("Account %d: Forward request failed: %v", account.ID, err) reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
...@@ -642,24 +710,34 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -642,24 +710,34 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量(subscription已在函数开头获取) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: usedAccount, Account: account,
Subscription: currentSubscription, Subscription: currentSubscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Bool("fallback_used", fallbackUsed),
)
return return
} }
if !retryWithFallback { if !retryWithFallback {
...@@ -682,6 +760,17 @@ func (h *GatewayHandler) Models(c *gin.Context) { ...@@ -682,6 +760,17 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID = &apiKey.Group.ID groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
platform = forcedPlatform
}
if platform == service.PlatformSora {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": service.DefaultSoraModels(h.cfg),
})
return
}
// Get available models from account configurations (without platform filter) // Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
...@@ -942,7 +1031,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b ...@@ -942,7 +1031,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b
// Handler 层只需短暂间隔后重新进入 Service 层即可。 // Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second const delay = 2 * time.Second
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount) logger.L().With(
zap.String("component", "handler.gateway.failover"),
zap.Duration("delay", delay),
zap.Int("retry_count", retryCount),
).Info("gateway.single_account_backoff_waiting")
select { select {
case <-ctx.Done(): case <-ctx.Done():
...@@ -1040,6 +1133,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e ...@@ -1040,6 +1133,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h.errorResponse(c, status, errType, message) h.errorResponse(c, status, errType, message)
} }
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse 返回Claude API格式的错误响应 // errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
...@@ -1067,6 +1169,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1067,6 +1169,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
} }
reqLog := requestLogger(
c,
"handler.gateway.count_tokens",
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
...@@ -1094,6 +1202,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1094,6 +1202,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
...@@ -1127,14 +1236,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1127,14 +1236,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return return
} }
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID, account.Platform)
// 转发请求(不记录使用量) // 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err) reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
// 错误响应已在 ForwardCountTokens 中处理 // 错误响应已在 ForwardCountTokens 中处理
return return
} }
...@@ -1398,7 +1508,25 @@ func billingErrorDetails(err error) (status int, code, message string) { ...@@ -1398,7 +1508,25 @@ func billingErrorDetails(err error) (status int, code, message string) {
} }
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
msg = err.Error() logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.Error(err),
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
} }
return http.StatusForbidden, "billing_error", msg return http.StatusForbidden, "billing_error", msg
} }
// submitUsageRecordTask hands a usage-recording task to the bounded worker pool
// when one is wired in; otherwise it executes the task synchronously with a 10s
// timeout so we never fall back to spawning unbounded goroutines.
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if pool := h.usageRecordWorkerPool; pool != nil {
		pool.Submit(task)
		return
	}
	// Fallback path: no pool injected — run inline under a bounded deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten verifies
// that a Claude-style 502 error body is written when no response has been
// committed yet, and that the helper reports having written it.
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	handler := &GatewayHandler{}
	require.True(t, handler.ensureForwardErrorResponse(ctx, false))
	require.Equal(t, http.StatusBadGateway, rec.Code)

	var payload map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
	assert.Equal(t, "error", payload["type"])

	errObj, ok := payload["error"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse verifies
// the fallback writer is a no-op once a response has already been committed.
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	// Commit a response first; the fallback must not clobber it.
	ctx.String(http.StatusTeapot, "already written")

	handler := &GatewayHandler{}
	require.False(t, handler.ensureForwardErrorResponse(ctx, false))
	require.Equal(t, http.StatusTeapot, rec.Code)
	assert.Equal(t, "already written", rec.Body.String())
}
...@@ -4,8 +4,9 @@ import ( ...@@ -4,8 +4,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
...@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator() ...@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context // 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) { func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map if c == nil || c.Request == nil {
var bodyMap map[string]any return
if len(body) > 0 { }
_ = json.Unmarshal(body, &bodyMap) // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
} }
// 验证是否为 Claude Code 客户端 isClaudeCode := false
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) if !strings.Contains(c.Request.URL.Path, "messages") {
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
}
// 更新 request context // 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
...@@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo ...@@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。 // 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 // 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil { if releaseFunc == nil {
return nil return nil
} }
var once sync.Once var once sync.Once
quit := make(chan struct{}) var stop func() bool
release := func() { release := func() {
once.Do(func() { once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc() releaseFunc()
close(quit) // 通知监听 goroutine 退出
}) })
} }
go func() { stop = context.AfterFunc(ctx, release)
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成,goroutine 退出
return
}
}()
return release return release
} }
...@@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou ...@@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
} }
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
...@@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64 ...@@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if result.Acquired { if acquired {
return result.ReleaseFunc, nil return releaseFunc, nil
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
...@@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ...@@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if result.Acquired { if acquired {
return result.ReleaseFunc, nil return releaseFunc, nil
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
...@@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ...@@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller). // streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
} }
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel() defer cancel()
// Try immediate acquire first (avoid unnecessary wait) acquireSlot := func() (*service.AcquireResult, error) {
var result *service.AcquireResult if slotType == "user" {
var err error return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
if slotType == "user" { }
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
} }
if result.Acquired {
return result.ReleaseFunc, nil if tryImmediate {
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
} }
// Determine if ping is needed (streaming + ping format defined) // Determine if ping is needed (streaming + ping format defined)
...@@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
backoff := initialBackoff backoff := initialBackoff
timer := time.NewTimer(backoff) timer := time.NewTimer(backoff)
defer timer.Stop() defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for { for {
select { select {
...@@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case <-timer.C: case <-timer.C:
// Try to acquire slot // Try to acquire slot
var result *service.AcquireResult result, err := acquireSlot()
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
if result.Acquired { if result.Acquired {
return result.ReleaseFunc, nil return result.ReleaseFunc, nil
} }
backoff = nextBackoff(backoff, rng) backoff = nextBackoff(backoff)
timer.Reset(backoff) timer.Reset(backoff)
} }
} }
...@@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
} }
// nextBackoff 计算下一次退避时间 // nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间 // current: 当前退避时间
// rng: 随机数生成器(可为 nil,此时不添加抖动)
// 返回值:下一次退避时间(100ms ~ 2s 之间) // 返回值:下一次退避时间(100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { func nextBackoff(current time.Duration) time.Duration {
// 指数退避:当前时间 * 1.5 // 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier) next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff { if next > maxBackoff {
next = maxBackoff next = maxBackoff
} }
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4 jitter := 0.8 + rand.Float64()*0.4
jittered := time.Duration(float64(next) * jitter) jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff { if jittered < initialBackoff {
return initialBackoff return initialBackoff
......
package handler
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
// TestNextBackoff_ExponentialGrowth walks the backoff chain for ten steps and
// checks that every value stays inside [initialBackoff, maxBackoff] despite the
// ±20% jitter layered on top of the 1.5x growth factor.
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
	cur := initialBackoff
	for round := 0; round < 10; round++ {
		got := nextBackoff(cur)
		assert.GreaterOrEqual(t, int64(got), int64(initialBackoff),
			"第 %d 次退避不应低于初始值 %v", round, initialBackoff)
		assert.LessOrEqual(t, int64(got), int64(maxBackoff),
			"第 %d 次退避不应超过最大值 %v", round, maxBackoff)
		cur = got
	}
}
// TestNextBackoff_BoundedByMaxBackoff ensures an oversized input is always
// clamped down to maxBackoff.
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
	for attempt := 0; attempt < 100; attempt++ {
		got := nextBackoff(10 * time.Second)
		assert.LessOrEqual(t, int64(got), int64(maxBackoff),
			"退避值不应超过 maxBackoff")
	}
}
// TestNextBackoff_BoundedByInitialBackoff ensures a tiny input is always
// raised to at least initialBackoff.
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
	for attempt := 0; attempt < 100; attempt++ {
		got := nextBackoff(1 * time.Millisecond)
		assert.GreaterOrEqual(t, int64(got), int64(initialBackoff),
			"退避值不应低于 initialBackoff")
	}
}
// TestNextBackoff_HasJitter calls nextBackoff 50 times with an identical input
// and requires at least two distinct outputs, proving random jitter is applied.
func TestNextBackoff_HasJitter(t *testing.T) {
	seen := make(map[time.Duration]bool)
	base := 500 * time.Millisecond
	for attempt := 0; attempt < 50; attempt++ {
		seen[nextBackoff(base)] = true
	}
	require.Greater(t, len(seen), 1,
		"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
}
// TestNextBackoff_InitialValueGrows chains 100 backoff steps from the initial
// value and asserts the average exceeds initialBackoff, i.e. the trend grows.
func TestNextBackoff_InitialValueGrows(t *testing.T) {
	const runs = 100
	cur := initialBackoff
	var total time.Duration
	for i := 0; i < runs; i++ {
		cur = nextBackoff(cur)
		total += cur
	}
	mean := total / time.Duration(runs)
	assert.Greater(t, int64(mean), int64(initialBackoff),
		"平均退避时间应大于初始退避值")
}
// TestNextBackoff_ConvergesToMaxBackoff iterates the backoff 20 times from the
// initial value and expects convergence near maxBackoff (±20% jitter allowed).
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
	cur := initialBackoff
	for step := 0; step < 20; step++ {
		cur = nextBackoff(cur)
	}
	// Jitter can undershoot by up to 20%, so the floor is 0.8 * maxBackoff.
	floor := time.Duration(float64(maxBackoff) * 0.8)
	assert.GreaterOrEqual(t, int64(cur), int64(floor),
		"经过多次退避后应收敛到 maxBackoff 附近")
}
// BenchmarkNextBackoff measures the cost of a single backoff-step computation,
// wrapping back to initialBackoff whenever the cap is exceeded.
func BenchmarkNextBackoff(b *testing.B) {
	cur := initialBackoff
	for n := 0; n < b.N; n++ {
		cur = nextBackoff(cur)
		if cur > maxBackoff {
			cur = initialBackoff
		}
	}
}
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// concurrencyCacheMock is a hand-rolled test double for the concurrency cache.
// Acquire behavior is pluggable via the *Fn hooks; release calls are counted so
// tests can assert slot bookkeeping. Counters are accessed via sync/atomic.
type concurrencyCacheMock struct {
	// acquireUserSlotFn, when set, overrides AcquireUserSlot; otherwise "not acquired".
	acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
	// acquireAccountSlotFn, when set, overrides AcquireAccountSlot; otherwise "not acquired".
	acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
	// releaseUserCalled counts ReleaseUserSlot invocations (read with atomic.LoadInt32).
	releaseUserCalled int32
	// releaseAccountCalled counts ReleaseAccountSlot invocations (read with atomic.LoadInt32).
	releaseAccountCalled int32
}
// AcquireAccountSlot delegates to acquireAccountSlotFn when provided; without a
// hook it reports the slot as not acquired.
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	if fn := m.acquireAccountSlotFn; fn != nil {
		return fn(ctx, accountID, maxConcurrency, requestID)
	}
	return false, nil
}
// ReleaseAccountSlot atomically records the release and always succeeds.
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	atomic.AddInt32(&m.releaseAccountCalled, 1)
	return nil
}
// GetAccountConcurrency is a no-op stub reporting zero in-flight requests.
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
// IncrementAccountWaitCount is a no-op stub that always grants the wait slot.
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}
// DecrementAccountWaitCount is a no-op stub that always succeeds.
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
	return nil
}
// GetAccountWaitingCount is a no-op stub reporting an empty wait queue.
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
// AcquireUserSlot delegates to acquireUserSlotFn when provided; without a hook
// it reports the slot as not acquired.
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	if fn := m.acquireUserSlotFn; fn != nil {
		return fn(ctx, userID, maxConcurrency, requestID)
	}
	return false, nil
}
// ReleaseUserSlot atomically records the release and always succeeds.
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
	atomic.AddInt32(&m.releaseUserCalled, 1)
	return nil
}
// GetUserConcurrency is a no-op stub reporting zero in-flight requests.
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}
// IncrementWaitCount is a no-op stub that always grants the wait slot.
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	return true, nil
}
// DecrementWaitCount is a no-op stub that always succeeds.
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
	return nil
}
// GetAccountsLoadBatch is a no-op stub returning an empty load map.
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	return map[int64]*service.AccountLoadInfo{}, nil
}
// GetUsersLoadBatch is a no-op stub returning an empty load map.
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
	return map[int64]*service.UserLoadInfo{}, nil
}
// CleanupExpiredAccountSlots is a no-op stub that always succeeds.
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	return nil
}
// TestConcurrencyHelper_TryAcquireUserSlot verifies the immediate-acquire path:
// a granted slot yields a non-nil release func, and invoking that func releases
// the slot exactly once through the cache.
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
	mock := &concurrencyCacheMock{
		acquireUserSlotFn: func(_ context.Context, _ int64, _ int, _ string) (bool, error) {
			return true, nil
		},
	}
	helper := NewConcurrencyHelper(service.NewConcurrencyService(mock), SSEPingFormatNone, time.Second)

	release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
	require.NoError(t, err)
	require.True(t, acquired)
	require.NotNil(t, release)

	release()
	require.Equal(t, int32(1), atomic.LoadInt32(&mock.releaseUserCalled))
}
// TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired verifies that a denied
// acquire yields no release func and never touches the release counter.
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
	mock := &concurrencyCacheMock{
		acquireAccountSlotFn: func(_ context.Context, _ int64, _ int, _ string) (bool, error) {
			return false, nil
		},
	}
	helper := NewConcurrencyHelper(service.NewConcurrencyService(mock), SSEPingFormatNone, time.Second)

	release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
	require.NoError(t, err)
	require.False(t, acquired)
	require.Nil(t, release)
	require.Equal(t, int32(0), atomic.LoadInt32(&mock.releaseAccountCalled))
}
package handler
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// helperConcurrencyCacheStub is a scriptable ConcurrencyCache test double:
// accountSeq/userSeq hold the boolean results handed out by successive
// Acquire*Slot calls, and the *Calls counters record invocation totals.
type helperConcurrencyCacheStub struct {
	mu sync.Mutex // guards all fields below
	accountSeq []bool // scripted AcquireAccountSlot results; false once drained
	userSeq []bool // scripted AcquireUserSlot results; false once drained
	accountAcquireCalls int
	userAcquireCalls int
	accountReleaseCalls int
	userReleaseCalls int
}
// AcquireAccountSlot pops the next scripted result from accountSeq; once the
// script is exhausted it keeps reporting the slot as unavailable.
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.accountAcquireCalls++
	if len(s.accountSeq) > 0 {
		granted := s.accountSeq[0]
		s.accountSeq = s.accountSeq[1:]
		return granted, nil
	}
	return false, nil
}
// ReleaseAccountSlot records the release call; it never fails.
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.accountReleaseCalls++
	return nil
}
// GetAccountConcurrency reports zero active requests for every account.
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
// IncrementAccountWaitCount always admits the caller into the account wait queue.
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}
// DecrementAccountWaitCount is a no-op that always succeeds.
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
	return nil
}
// GetAccountWaitingCount reports an empty account wait queue.
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
// AcquireUserSlot pops the next scripted result from userSeq; once the script
// is exhausted it keeps reporting the slot as unavailable.
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.userAcquireCalls++
	if len(s.userSeq) > 0 {
		granted := s.userSeq[0]
		s.userSeq = s.userSeq[1:]
		return granted, nil
	}
	return false, nil
}
// ReleaseUserSlot records the release call; it never fails.
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.userReleaseCalls++
	return nil
}
// GetUserConcurrency reports zero active requests for every user.
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}
// IncrementWaitCount always admits the caller into the user wait queue.
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	return true, nil
}
// DecrementWaitCount is a no-op that always succeeds.
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
	return nil
}
// GetAccountsLoadBatch fabricates a zero-load entry for each requested account.
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	out := make(map[int64]*service.AccountLoadInfo, len(accounts))
	for _, acc := range accounts {
		out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
	}
	return out, nil
}
// GetUsersLoadBatch fabricates a zero-load entry for each requested user.
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
	out := make(map[int64]*service.UserLoadInfo, len(users))
	for _, user := range users {
		out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
	}
	return out, nil
}
// CleanupExpiredAccountSlots is a no-op that always succeeds.
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	return nil
}
// newHelperTestContext builds a gin test context for the given HTTP method and
// path, backed by an httptest recorder for inspecting the written response.
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(method, path, nil)
	return ctx, recorder
}
// validClaudeCodeBodyJSON returns a request body that passes the strict
// Claude Code client validation: the expected model, the official CLI system
// prompt, and a metadata user_id carrying account and session segments.
func validClaudeCodeBodyJSON() []byte {
	const payload = `{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
}`
	return []byte(payload)
}
// TestSetClaudeCodeClientContext_FastPathAndStrictPath covers the Claude Code
// client-detection matrix: a non-CLI User-Agent is rejected outright, non-
// messages endpoints need only the CLI User-Agent, and /v1/messages requests
// additionally undergo strict header+body validation.
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
	t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "curl/8.6.0")
		SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
		require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
	})
	t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
		// Non-messages endpoints only need the CLI User-Agent; body may be nil.
		c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
		SetClaudeCodeClientContext(c, nil)
		require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
	})
	t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
		c.Request.Header.Set("X-App", "claude-code")
		c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
		c.Request.Header.Set("anthropic-version", "2023-06-01")
		SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
		require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
	})
	t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
		// Missing the headers and body fields required by strict validation.
		SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
		require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
	})
}
// TestWaitForSlotWithPingTimeout_AccountAndUserAcquire verifies the retry
// path of the wait loop: the scripted cache denies the first acquire and
// grants the second, so the helper must poll at least twice and then hand
// back a working release func, for both the account and user slot kinds.
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false, true},
		userSeq:    []bool{false, true},
	}
	concurrency := service.NewConcurrencyService(cache)
	// 5ms ping interval keeps the retry loop fast in tests.
	helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
	t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
		require.NoError(t, err)
		require.NotNil(t, release)
		require.False(t, streamStarted)
		release()
		require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
		require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
	})
	t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
		require.NoError(t, err)
		require.NotNil(t, release)
		release()
		require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
		require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
	})
}
// TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing verifies behavior when
// the cache never grants a slot: the wait loop must give up with a
// timeout-flagged ConcurrencyError, and in stream mode it must first start
// the SSE response and emit comment-style keepalive pings while waiting.
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false, false, false},
	}
	concurrency := service.NewConcurrencyService(cache)
	t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
		helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
		require.Nil(t, release)
		var cErr *ConcurrencyError
		require.ErrorAs(t, err, &cErr)
		require.True(t, cErr.IsTimeout)
	})
	t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
		// Ping interval (10ms) is shorter than the wait budget (70ms), so at
		// least one SSE comment frame must be written before the timeout.
		helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
		c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
		require.Nil(t, release)
		var cErr *ConcurrencyError
		require.ErrorAs(t, err, &cErr)
		require.True(t, cErr.IsTimeout)
		require.True(t, streamStarted)
		require.Contains(t, rec.Body.String(), ":\n\n")
	})
}
// TestWaitForSlotWithPingTimeout_AcquireError verifies that a hard cache
// failure (e.g. Redis down) surfaces as an error from the wait loop instead
// of being retried until the timeout elapses.
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
	failing := &helperConcurrencyCacheStubWithError{err: errors.New("redis unavailable")}
	h := NewConcurrencyHelper(service.NewConcurrencyService(failing), SSEPingFormatNone, 5*time.Millisecond)
	ctx, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
	started := false
	release, err := h.waitForSlotWithPingTimeout(ctx, "account", 1, 1, 200*time.Millisecond, false, &started, true)
	require.Nil(t, release)
	require.Error(t, err)
	require.Contains(t, err.Error(), "redis unavailable")
}
// TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff ensures
// the helper makes at least one acquire attempt right away: even with a tiny
// 30ms wait budget the scripted denial must be observed (no backoff sleep
// before the first try), after which the wait times out.
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false},
	}
	concurrency := service.NewConcurrencyService(cache)
	helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
	c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
	streamStarted := false
	release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
	require.Nil(t, release)
	var cErr *ConcurrencyError
	require.ErrorAs(t, err, &cErr)
	require.True(t, cErr.IsTimeout)
	require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
}
// helperConcurrencyCacheStubWithError embeds the base stub but makes every
// account-slot acquisition fail with the configured error.
type helperConcurrencyCacheStubWithError struct {
	helperConcurrencyCacheStub
	err error
}
// AcquireAccountSlot shadows the embedded implementation and always fails.
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	return false, s.err
}
...@@ -8,11 +8,9 @@ import ( ...@@ -8,11 +8,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"log"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...@@ -20,11 +18,13 @@ import ( ...@@ -20,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap"
) )
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 // geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
...@@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusInternalServerError, "User context not found") googleError(c, http.StatusInternalServerError, "User context not found")
return return
} }
reqLog := requestLogger(
c,
"handler.gemini_v1beta.models",
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
if !middleware.HasForcePlatform(c) { if !middleware.HasForcePlatform(c) {
...@@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
stream := action == "streamGenerateContent" stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
...@@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false waitCounted := false
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait { } else if !canWait {
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return return
} }
...@@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil { if err != nil {
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
...@@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait) // 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err) status, _, message := billingErrorDetails(err)
googleError(c, status, message) googleError(c, status, message)
return return
...@@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var sessionBoundAccountID int64 var sessionBoundAccountID int64
if sessionKey != "" { if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
} }
// === Gemini 内容摘要会话 Fallback 逻辑 === // === Gemini 内容摘要会话 Fallback 逻辑 ===
...@@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
matchedDigestChain = foundMatchedChain matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", reqLog.Info("gemini.digest_fallback_matched",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain)) zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
zap.Int64("account_id", foundAccountID),
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
)
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey // 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号 // 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
...@@ -346,7 +369,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -346,7 +369,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) reqLog.Warn("gemini.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{}) failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
...@@ -358,18 +384,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -358,18 +384,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID, account.Platform)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) reqLog.Info("gemini.sticky_session_account_switched",
zap.Int64("from_account_id", sessionBoundAccountID),
zap.Int64("to_account_id", account.ID),
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body) body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively") reqLog.Info("gemini.sticky_session_binding_missing",
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body) body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID sessionBoundAccountID = account.ID
...@@ -388,9 +420,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -388,9 +420,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted := false accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err) reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait { } else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID) reqLog.Info("gemini.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return return
} }
...@@ -412,6 +447,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -412,6 +447,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted, &streamStarted,
) )
if err != nil { if err != nil {
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
...@@ -420,7 +456,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -420,7 +456,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted = false accountWaitCounted = false
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
...@@ -454,7 +490,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -454,7 +490,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
lastFailoverErr = failoverErr lastFailoverErr = failoverErr
switchCount++ switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) reqLog.Warn("gemini.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) { if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return return
...@@ -463,7 +504,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -463,7 +504,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
continue continue
} }
// ForwardNative already wrote the response // ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err) reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return return
} }
...@@ -482,31 +523,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -482,31 +523,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account.ID, account.ID,
matchedDigestChain, matchedDigestChain,
); err != nil { ); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err) reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// 6) record usage async (Gemini 使用长上下文双倍计费) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: ip, IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return return
} }
} }
......
...@@ -39,6 +39,7 @@ type Handlers struct { ...@@ -39,6 +39,7 @@ type Handlers struct {
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler Setting *SettingHandler
Totp *TotpHandler Totp *TotpHandler
} }
......
package handler
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// executeUserIdempotentJSON runs a user-scoped handler action at most once per
// Idempotency-Key and writes the JSON result.
//
// Behavior:
//   - With no idempotency coordinator configured, the action executes directly
//     (no dedup guarantees) and its result is returned as-is.
//   - Otherwise the coordinator deduplicates on (scope, actor, method, route,
//     key, payload); a replayed response is flagged with the
//     X-Idempotency-Replayed header.
//   - When the idempotency store is unavailable the request fails closed; a
//     Retry-After header is set when the error carries a retry hint.
func executeUserIdempotentJSON(
	c *gin.Context,
	scope string,
	payload any,
	ttl time.Duration,
	execute func(context.Context) (any, error),
) {
	coordinator := service.DefaultIdempotencyCoordinator()
	if coordinator == nil {
		// No coordinator wired (tests / minimal deployments): run directly.
		data, err := execute(c.Request.Context())
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		response.Success(c, data)
		return
	}
	// Partition idempotency records per authenticated user; unauthenticated
	// callers deterministically fall back to the "user:0" bucket.
	actorScope := "user:0"
	if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
		actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
	}
	result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
		Scope:          scope,
		ActorScope:     actorScope,
		Method:         c.Request.Method,
		Route:          c.FullPath(),
		IdempotencyKey: c.GetHeader("Idempotency-Key"),
		Payload:        payload,
		RequireKey:     true,
		TTL:            ttl,
	}, execute)
	if err != nil {
		if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
			service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
			logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
		}
		if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
			c.Header("Retry-After", strconv.Itoa(retryAfter))
		}
		response.ErrorFrom(c, err)
		return
	}
	// Defensive nil guard: the original code checked result != nil for the
	// Replayed flag but still dereferenced result.Data unconditionally, which
	// would panic on a nil result with a nil error.
	if result == nil {
		response.Success(c, nil)
		return
	}
	if result.Replayed {
		c.Header("X-Idempotency-Replayed", "true")
	}
	response.Success(c, result.Data)
}
package handler
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// userStoreUnavailableRepoStub simulates a fully unreachable idempotency
// store: every repository method fails with a "store unavailable" error,
// which lets tests exercise the handler's fail-close strategy.
type userStoreUnavailableRepoStub struct{}
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
	return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
	return nil, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
	return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
	return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
	return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
	return 0, errors.New("store unavailable")
}
// userMemoryIdempotencyRepoStub is an in-memory IdempotencyRepo test double
// keyed by "scope|keyHash"; all access is serialized through mu.
type userMemoryIdempotencyRepoStub struct {
	mu sync.Mutex
	nextID int64 // next record ID to assign
	data map[string]*service.IdempotencyRecord
}
// newUserMemoryIdempotencyRepoStub returns an empty stub with IDs starting at 1.
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
	return &userMemoryIdempotencyRepoStub{
		nextID: 1,
		data: make(map[string]*service.IdempotencyRecord),
	}
}
// key builds the composite "<scope>|<keyHash>" map key used for record lookup.
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
	composite := scope + "|" + keyHash
	return composite
}
// clone deep-copies a record, duplicating every pointer field so callers can
// mutate the copy without aliasing the stored state. Returns nil for nil input.
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
	if in == nil {
		return nil
	}
	// Shallow copy first, then re-point each pointer field at a fresh value.
	out := *in
	if in.LockedUntil != nil {
		v := *in.LockedUntil
		out.LockedUntil = &v
	}
	if in.ResponseBody != nil {
		v := *in.ResponseBody
		out.ResponseBody = &v
	}
	if in.ResponseStatus != nil {
		v := *in.ResponseStatus
		out.ResponseStatus = &v
	}
	if in.ErrorReason != nil {
		v := *in.ErrorReason
		out.ErrorReason = &v
	}
	return &out
}
// CreateProcessing inserts a new processing record unless one already exists
// for the same (scope, key-hash) pair; it reports whether insertion happened
// and writes the assigned ID back into the caller's record.
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	slot := r.key(record.Scope, record.IdempotencyKeyHash)
	if _, exists := r.data[slot]; exists {
		return false, nil
	}
	stored := r.clone(record)
	stored.ID = r.nextID
	r.nextID++
	r.data[slot] = stored
	record.ID = stored.ID
	return true, nil
}
// GetByScopeAndKeyHash returns a deep copy of the matching record, or nil
// (with a nil error) when no record exists for the pair.
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	stored := r.data[r.key(scope, keyHash)]
	return r.clone(stored), nil
}
// TryReclaim attempts to take over a record currently in fromStatus whose lock
// has lapsed, flipping it back to processing with a fresh lock and expiry.
// Returns false when the record is missing, in another status, or still locked.
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != fromStatus {
			return false, nil
		}
		// Reclaim only once the previous holder's lock has expired.
		if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
			return false, nil
		}
		rec.Status = service.IdempotencyStatusProcessing
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		rec.ErrorReason = nil
		return true, nil
	}
	return false, nil
}
// ExtendProcessingLock renews the lock and expiry of a processing record, but
// only while it is still processing and owned by the caller (matching
// requestFingerprint). Returns false otherwise, including for unknown IDs.
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, rec := range r.data {
		if rec.ID != id {
			continue
		}
		if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
			return false, nil
		}
		rec.LockedUntil = &newLockedUntil
		rec.ExpiresAt = newExpiresAt
		return true, nil
	}
	return false, nil
}
// MarkSucceeded transitions the record with the given ID to the succeeded
// status, storing the response snapshot and clearing the lock and error.
// Unknown IDs are silently ignored.
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, stored := range r.data {
		if stored.ID == id {
			status := responseStatus
			body := responseBody
			stored.Status = service.IdempotencyStatusSucceeded
			stored.LockedUntil = nil
			stored.ExpiresAt = expiresAt
			stored.ResponseStatus = &status
			stored.ResponseBody = &body
			stored.ErrorReason = nil
			break
		}
	}
	return nil
}
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
// DeleteExpired is a no-op in this in-memory stub: it always reports zero
// deletions so callers exercising cleanup paths see a success with no effect.
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
	return 0, nil
}
// withUserSubject returns gin middleware that injects an authenticated
// AuthSubject with the given user ID into the request context, simulating
// what the real auth middleware would do in production.
func withUserSubject(userID int64) gin.HandlerFunc {
	return func(c *gin.Context) {
		subject := middleware.AuthSubject{UserID: userID}
		c.Set(string(middleware.ContextKeyUser), subject)
		c.Next()
	}
}
// TestExecuteUserIdempotentJSONFallbackWithoutCoordinator verifies that when
// no idempotency coordinator is registered, the handler simply executes the
// wrapped function once and returns 200.
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(nil)

	callCount := 0
	engine := gin.New()
	engine.Use(withUserSubject(1))
	engine.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			callCount++
			return gin.H{"ok": true}, nil
		})
	})

	request := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, 1, callCount)
}
// TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable verifies the
// fail-closed behavior: when an Idempotency-Key is supplied but the backing
// store errors, the request is rejected with 503 and the wrapped function
// never runs.
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
	gin.SetMode(gin.TestMode)
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})

	callCount := 0
	engine := gin.New()
	engine.Use(withUserSubject(2))
	engine.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			callCount++
			return gin.H{"ok": true}, nil
		})
	})

	request := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Idempotency-Key", "k1")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	require.Equal(t, 0, callCount)
}
// TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay
// verifies that two concurrent requests sharing one Idempotency-Key execute
// the handler body exactly once (the loser gets 200 or 409), and that a later
// request with the same key replays the stored response with the
// X-Idempotency-Replayed header set — without re-running the side effect.
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
	gin.SetMode(gin.TestMode)
	repo := newUserMemoryIdempotencyRepoStub()
	cfg := service.DefaultIdempotencyConfig()
	cfg.ProcessingTimeout = 2 * time.Second
	service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
	t.Cleanup(func() {
		service.SetDefaultIdempotencyCoordinator(nil)
	})
	// executed counts side-effect executions; atomic because two goroutines
	// may reach the handler body concurrently.
	var executed atomic.Int32
	router := gin.New()
	router.Use(withUserSubject(3))
	router.POST("/idempotent", func(c *gin.Context) {
		executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
			executed.Add(1)
			// Sleep keeps the first request "in flight" long enough for the
			// second concurrent request to observe the processing state.
			time.Sleep(80 * time.Millisecond)
			return gin.H{"ok": true}, nil
		})
	})
	// call issues one POST with the shared Idempotency-Key and returns the
	// response status and headers.
	call := func() (int, http.Header) {
		req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Idempotency-Key", "same-user-key")
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec.Code, rec.Header()
	}
	var status1, status2 int
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); status1, _ = call() }()
	go func() { defer wg.Done(); status2, _ = call() }()
	wg.Wait()
	// Either request may win the slot; the other sees 200 (after replay) or
	// 409 (conflict while still processing). Both outcomes are acceptable.
	require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
	require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
	require.Equal(t, int32(1), executed.Load())
	// A third call after completion must replay the stored response.
	status3, headers3 := call()
	require.Equal(t, http.StatusOK, status3)
	require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
	require.Equal(t, int32(1), executed.Load())
}
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// requestLogger builds a zap logger for the current request. It prefers the
// request-scoped logger from the gin request context when available, falling
// back to the process-wide logger, and prepends a "component" field when the
// component name is non-empty.
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
	log := logger.L()
	if c != nil && c.Request != nil {
		log = logger.FromContext(c.Request.Context())
	}
	if component == "" {
		return log.With(fields...)
	}
	// Prepend the component so it appears first in the structured output.
	all := make([]zap.Field, 0, len(fields)+1)
	all = append(all, zap.String("component", component))
	all = append(all, fields...)
	return log.With(all...)
}
...@@ -6,18 +6,19 @@ import ( ...@@ -6,18 +6,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
) )
// OpenAIGatewayHandler handles OpenAI API gateway requests // OpenAIGatewayHandler handles OpenAI API gateway requests
...@@ -25,6 +26,7 @@ type OpenAIGatewayHandler struct { ...@@ -25,6 +26,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
...@@ -36,6 +38,7 @@ func NewOpenAIGatewayHandler( ...@@ -36,6 +38,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
...@@ -51,6 +54,7 @@ func NewOpenAIGatewayHandler( ...@@ -51,6 +54,7 @@ func NewOpenAIGatewayHandler(
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
...@@ -60,6 +64,8 @@ func NewOpenAIGatewayHandler( ...@@ -60,6 +64,8 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint // Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses // POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware) // Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok { if !ok {
...@@ -72,6 +78,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -72,6 +78,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
} }
reqLog := requestLogger(
c,
"handler.openai_gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body // Read request body
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
...@@ -91,57 +104,57 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -91,57 +104,57 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification // 校验请求体 JSON 合法性
var reqBody map[string]any if !gjson.ValidBytes(body) {
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
// Extract model and stream // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
reqModel, _ := reqBody["model"].(string) modelResult := gjson.GetBytes(body, "model")
reqStream, _ := reqBody["stream"].(bool) if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return return
} }
reqModel := modelResult.String()
userAgent := c.GetHeader("User-Agent") streamResult := gjson.GetBytes(body, "stream")
if !openai.IsCodexCLIRequest(userAgent) { if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
existingInstructions, _ := reqBody["instructions"].(string) h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
if strings.TrimSpace(existingInstructions) == "" { return
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
}
} }
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
// 或带 id 且与 call_id 匹配的 item_reference。 // 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) { // 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
previousResponseID, _ := reqBody["previous_response_id"].(string) if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { var reqBody map[string]any
if service.HasFunctionCallOutputMissingCallID(reqBody) { if err := json.Unmarshal(body, &reqBody); err == nil {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") if service.HasFunctionCallOutput(reqBody) {
return previousResponseID, _ := reqBody["previous_response_id"].(string)
} if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
callIDs := service.FunctionCallOutputCallIDs(reqBody) if service.HasFunctionCallOutputMissingCallID(reqBody) {
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { reqLog.Warn("openai.request_validation_failed",
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) zap.String("reason", "function_call_output_missing_call_id"),
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") )
return h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
} }
} }
} }
...@@ -157,34 +170,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -157,34 +170,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get subscription info (may be nil) // Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
maxWait := service.CalculateMaxWait(subject.Concurrency) routingStart := time.Now()
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false // 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
// On error, allow request to proceed h.handleConcurrencyError(c, err, "user", streamStarted)
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return return
} }
if err == nil && canWait {
waitCounted = true waitCounted := false
} if !userAcquired {
defer func() { // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
if waitCounted { maxWait := service.CalculateMaxWait(subject.Concurrency)
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if waitErr == nil && canWait {
waitCounted = true
} }
}() defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
// 1. First acquire user concurrency slot userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil {
if err != nil { reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted)
h.handleConcurrencyError(c, err, "user", streamStarted) return
return }
} }
// User slot acquired: no longer waiting.
// 用户槽位已获取:退出等待队列计数。
if waitCounted { if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false waitCounted = false
...@@ -197,14 +224,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -197,14 +224,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err) status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted) h.handleStreamingAwareError(c, status, code, message, streamStarted)
return return
} }
// Generate session hash (header first; fallback to prompt_cache_key) // Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) sessionHash := h.gatewayService.GenerateSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
...@@ -213,12 +240,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -213,12 +240,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for { for {
// Select account supporting the requested model // Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) reqLog.Warn("openai.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -229,8 +259,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -229,8 +259,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
account := selection.Account account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID, account.Platform)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
...@@ -239,53 +269,87 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -239,53 +269,87 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return return
} }
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( // 先快速尝试一次账号槽位,命中则跳过等待计数写入。
c, fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
c.Request.Context(),
account.ID, account.ID,
selection.WaitPlan.MaxConcurrency, selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted { if fastAcquired {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) accountReleaseFunc = fastReleaseFunc
accountWaitCounted = false if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
} reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { }
log.Printf("Bind sticky session failed: %v", err) } else {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request // Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
...@@ -296,11 +360,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -296,11 +360,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
switchCount++ switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue continue
} }
// Error response already handled in Forward, just log wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
log.Printf("Account %d: Forward request failed: %v", account.ID, err) reqLog.Error("openai.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
...@@ -308,27 +381,72 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -308,27 +381,72 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// Async record usage // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: ip, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP) })
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return return
} }
} }
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
}
v, ok := c.Get(key)
if !ok {
return 0, false
}
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
case int32:
return int64(t), true
case float64:
return int64(t), true
default:
return 0, false
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
...@@ -397,8 +515,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status ...@@ -397,8 +515,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close // Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if ok { if ok {
// Send error event in OpenAI SSE format // Send error event in OpenAI SSE format with proper JSON marshaling
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err) _ = c.Error(err)
} }
...@@ -411,6 +540,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status ...@@ -411,6 +540,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
h.errorResponse(c, status, errType, message) h.errorResponse(c, status, errType, message)
} }
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse returns OpenAI API format error response // errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
......
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// TestOpenAIHandleStreamingAwareError_JSONEscaping verifies that when a stream
// has already started, errors are emitted as an SSE "error" event whose data
// payload is valid JSON even if the message contains quotes, backslashes, or
// control characters (i.e. the payload is produced by real JSON marshaling,
// not naive string interpolation).
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
	tests := []struct {
		name    string
		errType string
		message string
	}{
		{
			name:    "包含双引号的消息",
			errType: "server_error",
			message: `upstream returned "invalid" response`,
		},
		{
			name:    "包含反斜杠的消息",
			errType: "server_error",
			message: `path C:\Users\test\file.txt not found`,
		},
		{
			name:    "包含双引号和反斜杠的消息",
			errType: "upstream_error",
			message: `error parsing "key\value": unexpected token`,
		},
		{
			name:    "包含换行符的消息",
			errType: "server_error",
			message: "line1\nline2\ttab",
		},
		{
			name:    "普通消息",
			errType: "upstream_error",
			message: "Upstream service temporarily unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gin.SetMode(gin.TestMode)
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			h := &OpenAIGatewayHandler{}
			// streamStarted=true forces the SSE error path.
			h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
			body := w.Body.String()
			// Validate SSE framing: event: error\ndata: {JSON}\n\n
			assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
			assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
			// Extract the data line.
			lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
			require.Len(t, lines, 2, "应有 event 行和 data 行")
			dataLine := lines[1]
			require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
			jsonStr := strings.TrimPrefix(dataLine, "data: ")
			// The data payload must parse as valid JSON.
			var parsed map[string]any
			err := json.Unmarshal([]byte(jsonStr), &parsed)
			require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
			// Verify the {"error": {"type": ..., "message": ...}} structure
			// round-trips the original values exactly.
			errorObj, ok := parsed["error"].(map[string]any)
			require.True(t, ok, "应包含 error 对象")
			assert.Equal(t, tt.errType, errorObj["type"])
			assert.Equal(t, tt.message, errorObj["message"])
		})
	}
}
// TestOpenAIHandleStreamingAwareError_NonStreaming checks that when no SSE
// stream has started, the handler falls back to a plain JSON error response
// carrying the expected status code, error type, and message.
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	handler := &OpenAIGatewayHandler{}
	handler.handleStreamingAwareError(ctx, http.StatusBadGateway, "upstream_error", "test error", false)

	// Non-streaming failures must surface as a JSON body with the status code.
	assert.Equal(t, http.StatusBadGateway, recorder.Code)

	var payload map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	errObj, isMap := payload["error"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "test error", errObj["message"])
}
// TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten verifies
// that the fallback 502 error body is produced (and reported as written)
// when nothing has been written to the response yet.
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	handler := &OpenAIGatewayHandler{}
	require.True(t, handler.ensureForwardErrorResponse(ctx, false))
	require.Equal(t, http.StatusBadGateway, recorder.Code)

	var payload map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	errObj, isMap := payload["error"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "upstream_error", errObj["type"])
	assert.Equal(t, "Upstream request failed", errObj["message"])
}
// TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse ensures
// that once a response has been written, the fallback writer is a no-op and
// leaves the original status and body intact.
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	// Pre-write a response; the fallback must not touch it.
	ctx.String(http.StatusTeapot, "already written")

	handler := &OpenAIGatewayHandler{}
	require.False(t, handler.ensureForwardErrorResponse(ctx, false))
	require.Equal(t, http.StatusTeapot, recorder.Code)
	assert.Equal(t, "already written", recorder.Body.String())
}
// TestOpenAIHandler_GjsonExtraction verifies that gjson extracts the
// model/stream fields from a request body correctly: model is only accepted
// when it is a JSON string, and a missing stream field defaults to false.
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
	cases := []struct {
		name       string
		body       string
		wantModel  string
		wantStream bool
	}{
		{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
		{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
		{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
		{"model 缺失", `{"stream":true}`, "", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			payload := []byte(tc.body)

			// Only a JSON string counts as a model; anything else maps to "".
			var model string
			if res := gjson.GetBytes(payload, "model"); res.Type == gjson.String {
				model = res.String()
			}
			require.Equal(t, tc.wantModel, model)
			require.Equal(t, tc.wantStream, gjson.GetBytes(payload, "stream").Bool())
		})
	}
}
// TestOpenAIHandler_GjsonValidation verifies the post-fix JSON validity and
// type checks: only well-formed JSON with a string `model` and a boolean
// `stream` passes the gjson-based validation.
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
	// Malformed JSON is rejected by gjson.ValidBytes.
	require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
	// model as a number → type is not gjson.String, must be rejected.
	body := []byte(`{"model":123}`)
	modelResult := gjson.GetBytes(body, "model")
	require.True(t, modelResult.Exists())
	require.NotEqual(t, gjson.String, modelResult.Type)
	// model as null → type is not gjson.String, must be rejected.
	body2 := []byte(`{"model":null}`)
	modelResult2 := gjson.GetBytes(body2, "model")
	require.True(t, modelResult2.Exists())
	require.NotEqual(t, gjson.String, modelResult2.Type)
	// stream as a string → type is neither True nor False, must be rejected.
	body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
	streamResult := gjson.GetBytes(body3, "stream")
	require.True(t, streamResult.Exists())
	require.NotEqual(t, gjson.True, streamResult.Type)
	require.NotEqual(t, gjson.False, streamResult.Type)
	// stream as an int → same as above.
	body4 := []byte(`{"model":"gpt-4","stream":1}`)
	streamResult2 := gjson.GetBytes(body4, "stream")
	require.True(t, streamResult2.Exists())
	require.NotEqual(t, gjson.True, streamResult2.Type)
	require.NotEqual(t, gjson.False, streamResult2.Type)
}
// TestOpenAIHandler_InstructionsInjection verifies the gjson/sjson logic that
// injects a default `instructions` value into request bodies.
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
	// Case 1: no instructions present → inject.
	body := []byte(`{"model":"gpt-4"}`)
	existing := gjson.GetBytes(body, "instructions").String()
	require.Empty(t, existing)
	newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
	require.NoError(t, err)
	require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
	// Case 2: instructions already present → do not overwrite.
	body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
	existing2 := gjson.GetBytes(body2, "instructions").String()
	require.Equal(t, "existing", existing2)
	// Case 3: whitespace-only instructions → treated as empty, inject.
	body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
	existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
	require.Empty(t, existing3)
	// Case 4: sjson.SetBytes errors must not cause a panic.
	// Valid JSON never produces an sjson error; verify the return values are
	// handled properly and the result remains valid JSON.
	validBody := []byte(`{"model":"gpt-4"}`)
	result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
	require.NoError(t, setErr)
	require.True(t, gjson.ValidBytes(result))
}
...@@ -41,9 +41,8 @@ const ( ...@@ -41,9 +41,8 @@ const (
) )
type opsErrorLogJob struct { type opsErrorLogJob struct {
ops *service.OpsService ops *service.OpsService
entry *service.OpsInsertErrorLogInput entry *service.OpsInsertErrorLogInput
requestBody []byte
} }
var ( var (
...@@ -58,6 +57,7 @@ var ( ...@@ -58,6 +57,7 @@ var (
opsErrorLogEnqueued atomic.Int64 opsErrorLogEnqueued atomic.Int64
opsErrorLogDropped atomic.Int64 opsErrorLogDropped atomic.Int64
opsErrorLogProcessed atomic.Int64 opsErrorLogProcessed atomic.Int64
opsErrorLogSanitized atomic.Int64
opsErrorLogLastDropLogAt atomic.Int64 opsErrorLogLastDropLogAt atomic.Int64
...@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() { ...@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
} }
}() }()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, job.requestBody) _ = job.ops.RecordError(ctx, job.entry, nil)
cancel() cancel()
opsErrorLogProcessed.Add(1) opsErrorLogProcessed.Add(1)
}() }()
...@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() { ...@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
} }
} }
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
if ops == nil || entry == nil { if ops == nil || entry == nil {
return return
} }
...@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo ...@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
} }
select { select {
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
opsErrorLogQueueLen.Add(1) opsErrorLogQueueLen.Add(1)
opsErrorLogEnqueued.Add(1) opsErrorLogEnqueued.Add(1)
default: default:
...@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 { ...@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
return opsErrorLogProcessed.Load() return opsErrorLogProcessed.Load()
} }
func OpsErrorLogSanitizedTotal() int64 {
return opsErrorLogSanitized.Load()
}
func maybeLogOpsErrorLogDrop() { func maybeLogOpsErrorLogDrop() {
now := time.Now().Unix() now := time.Now().Unix()
...@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() { ...@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
queueCap := OpsErrorLogQueueCapacity() queueCap := OpsErrorLogQueueCapacity()
log.Printf( log.Printf(
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
queued, queued,
queueCap, queueCap,
opsErrorLogEnqueued.Load(), opsErrorLogEnqueued.Load(),
opsErrorLogDropped.Load(), opsErrorLogDropped.Load(),
opsErrorLogProcessed.Load(), opsErrorLogProcessed.Load(),
opsErrorLogSanitized.Load(),
) )
} }
...@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody ...@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
if c == nil { if c == nil {
return return
} }
model = strings.TrimSpace(model)
c.Set(opsModelKey, model) c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream) c.Set(opsStreamKey, stream)
if len(requestBody) > 0 { if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody) c.Set(opsRequestBodyKey, requestBody)
} }
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
}
} }
func setOpsSelectedAccount(c *gin.Context, accountID int64) { func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
v, ok := c.Get(opsRequestBodyKey)
if !ok {
return
}
raw, ok := v.([]byte)
if !ok || len(raw) == 0 {
return
}
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
opsErrorLogSanitized.Add(1)
}
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 { if c == nil || accountID <= 0 {
return return
} }
c.Set(opsAccountIDKey, accountID) c.Set(opsAccountIDKey, accountID)
if c.Request != nil {
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
if len(platform) > 0 {
p := strings.TrimSpace(platform[0])
if p != "" {
ctx = context.WithValue(ctx, ctxkey.Platform, p)
}
}
c.Request = c.Request.WithContext(ctx)
}
} }
type opsCaptureWriter struct { type opsCaptureWriter struct {
...@@ -507,6 +543,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -507,6 +543,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0, RetryCount: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
applyOpsLatencyFieldsFromContext(c, entry)
if apiKey != nil { if apiKey != nil {
entry.APIKeyID = &apiKey.ID entry.APIKeyID = &apiKey.ID
...@@ -528,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -528,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP entry.ClientIP = &clientIP
} }
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal. // Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
// Skip logging if a passthrough rule with skip_monitoring=true matched. // Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
...@@ -544,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -544,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
} }
enqueueOpsErrorLog(ops, entry, requestBody) enqueueOpsErrorLog(ops, entry)
return return
} }
...@@ -632,6 +664,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -632,6 +664,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0, RetryCount: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
applyOpsLatencyFieldsFromContext(c, entry)
// Capture upstream error context set by gateway services (if present). // Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data. // This does NOT affect the client response; it enriches Ops troubleshooting data.
...@@ -707,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -707,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP entry.ClientIP = &clientIP
} }
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity. // Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc. // Do NOT store Authorization/Cookie/etc.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
enqueueOpsErrorLog(ops, entry, requestBody) enqueueOpsErrorLog(ops, entry)
} }
} }
...@@ -760,6 +788,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string { ...@@ -760,6 +788,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
return &s return &s
} }
// applyOpsLatencyFieldsFromContext copies the per-phase latency values stored
// on the gin context into the ops error-log entry. Keys that are missing or
// hold non-numeric/negative values leave the corresponding field nil.
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
	if c == nil || entry == nil {
		return
	}
	// Map each context key to its destination field and assign uniformly.
	targets := []struct {
		key string
		dst **int64
	}{
		{service.OpsAuthLatencyMsKey, &entry.AuthLatencyMs},
		{service.OpsRoutingLatencyMsKey, &entry.RoutingLatencyMs},
		{service.OpsUpstreamLatencyMsKey, &entry.UpstreamLatencyMs},
		{service.OpsResponseLatencyMsKey, &entry.ResponseLatencyMs},
		{service.OpsTimeToFirstTokenMsKey, &entry.TimeToFirstTokenMs},
	}
	for _, tgt := range targets {
		*tgt.dst = getContextLatencyMs(c, tgt.key)
	}
}
// getContextLatencyMs reads an integer latency value (in milliseconds) stored
// on the gin context under key. It returns nil when the key is blank or not
// set, when the stored value is not a supported numeric type, or when the
// value is negative.
func getContextLatencyMs(c *gin.Context, key string) *int64 {
	if c == nil || strings.TrimSpace(key) == "" {
		return nil
	}
	stored, found := c.Get(key)
	if !found {
		return nil
	}
	var value int64
	switch n := stored.(type) {
	case int:
		value = int64(n)
	case int32:
		value = int64(n)
	case int64:
		value = n
	case float64:
		// Truncates toward zero for float-typed values.
		value = int64(n)
	default:
		return nil
	}
	// Negative latencies are nonsensical; treat them as absent.
	if value < 0 {
		return nil
	}
	return &value
}
type parsedOpsError struct { type parsedOpsError struct {
ErrorType string ErrorType string
Message string Message string
......
package handler
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// resetOpsErrorLoggerStateForTest tears down and re-initializes all
// package-level ops-error-logger state so each test starts from a clean slate.
//
// Order matters: the queue is detached and marked stopping under the lock,
// the detached channel is closed outside the lock so running workers can
// exit, and only after the worker WaitGroup settles are the sync primitives
// and counters rebuilt.
func resetOpsErrorLoggerStateForTest(t *testing.T) {
	t.Helper()
	// Detach the queue and flag shutdown while holding the mutex so no new
	// jobs can be enqueued concurrently.
	opsErrorLogMu.Lock()
	ch := opsErrorLogQueue
	opsErrorLogQueue = nil
	opsErrorLogStopping = true
	opsErrorLogMu.Unlock()
	// Closing the detached channel lets in-flight workers drain and return.
	if ch != nil {
		close(ch)
	}
	opsErrorLogWorkersWg.Wait()
	// Rebuild every sync primitive and zero all counters.
	opsErrorLogOnce = sync.Once{}
	opsErrorLogStopOnce = sync.Once{}
	opsErrorLogWorkersWg = sync.WaitGroup{}
	opsErrorLogMu = sync.RWMutex{}
	opsErrorLogStopping = false
	opsErrorLogQueueLen.Store(0)
	opsErrorLogEnqueued.Store(0)
	opsErrorLogDropped.Store(0)
	opsErrorLogProcessed.Store(0)
	opsErrorLogSanitized.Store(0)
	opsErrorLogLastDropLogAt.Store(0)
	opsErrorLogShutdownCh = make(chan struct{})
	opsErrorLogShutdownOnce = sync.Once{}
	opsErrorLogDrained.Store(false)
}
// TestAttachOpsRequestBodyToEntry_SanitizeAndTrim confirms that a captured
// request body is sanitized (secrets redacted) before being attached to the
// error-log entry, while the original byte size is preserved, and that the
// sanitized counter is bumped exactly once.
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	payload := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
	setOpsRequestContext(ctx, "claude-3", false, payload)

	logEntry := &service.OpsInsertErrorLogInput{}
	attachOpsRequestBodyToEntry(ctx, logEntry)

	require.NotNil(t, logEntry.RequestBodyBytes)
	require.Equal(t, len(payload), *logEntry.RequestBodyBytes)
	require.NotNil(t, logEntry.RequestBodyJSON)
	require.NotContains(t, *logEntry.RequestBodyJSON, "secret-token")
	require.Contains(t, *logEntry.RequestBodyJSON, "[REDACTED]")
	require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
// TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize checks that a
// non-JSON body yields no sanitized JSON on the entry, but the original byte
// size is still recorded and the sanitized counter still increments.
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	payload := []byte("not-json")
	setOpsRequestContext(ctx, "claude-3", false, payload)

	logEntry := &service.OpsInsertErrorLogInput{}
	attachOpsRequestBodyToEntry(ctx, logEntry)

	require.Nil(t, logEntry.RequestBodyJSON)
	require.NotNil(t, logEntry.RequestBodyBytes)
	require.Equal(t, len(payload), *logEntry.RequestBodyBytes)
	require.False(t, logEntry.RequestBodyTruncated)
	require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
// TestEnqueueOpsErrorLog_QueueFullDrop verifies that when the queue is full,
// a further enqueue is dropped (counted in the dropped total) instead of
// blocking the caller.
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	// Consume the sync.Once so enqueueOpsErrorLog cannot start real workers;
	// a 1-slot test queue then exercises the full-queue degradation path.
	opsErrorLogOnce.Do(func() {})
	opsErrorLogMu.Lock()
	opsErrorLogQueue = make(chan opsErrorLogJob, 1)
	opsErrorLogMu.Unlock()
	ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
	entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
	// First enqueue fills the queue; the second must be dropped.
	enqueueOpsErrorLog(ops, entry)
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
	require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
	require.Equal(t, int64(1), OpsErrorLogQueueLength())
}
// TestAttachOpsRequestBodyToEntry_EarlyReturnBranches exercises every guard
// clause in attachOpsRequestBodyToEntry: nil context, nil entry, missing
// body key, wrong value type, and empty byte slice all leave the entry
// untouched and the sanitized counter at zero.
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	gin.SetMode(gin.TestMode)
	entry := &service.OpsInsertErrorLogInput{}
	// nil context / nil entry must be no-ops.
	attachOpsRequestBodyToEntry(nil, entry)
	attachOpsRequestBodyToEntry(&gin.Context{}, nil)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	// No request-body key set on the context.
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)
	require.False(t, entry.RequestBodyTruncated)
	// Stored value has the wrong type (not []byte).
	c.Set(opsRequestBodyKey, "not-bytes")
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)
	// Empty byte slice is ignored.
	c.Set(opsRequestBodyKey, []byte{})
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)
	require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
}
// TestEnqueueOpsErrorLog_EarlyReturnBranches covers every path where
// enqueueOpsErrorLog must silently drop the job without counting an enqueue:
// nil arguments, shutdown signalled, stopping flag set, and a nil queue.
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
	entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
	// nil-argument branches.
	enqueueOpsErrorLog(nil, entry)
	enqueueOpsErrorLog(ops, nil)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
	// shutdown branch.
	close(opsErrorLogShutdownCh)
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
	// stopping branch.
	resetOpsErrorLoggerStateForTest(t)
	opsErrorLogMu.Lock()
	opsErrorLogStopping = true
	opsErrorLogMu.Unlock()
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
	// nil-queue branch (the Once is consumed first so no worker can start
	// and interfere with the assertion).
	resetOpsErrorLoggerStateForTest(t)
	opsErrorLogOnce.Do(func() {})
	opsErrorLogMu.Lock()
	opsErrorLogQueue = nil
	opsErrorLogMu.Unlock()
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
}
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests.
type SoraGatewayHandler struct {
	gatewayService        *service.GatewayService
	soraGatewayService    *service.SoraGatewayService
	billingCacheService   *service.BillingCacheService
	usageRecordWorkerPool *service.UsageRecordWorkerPool
	concurrencyHelper     *ConcurrencyHelper
	maxAccountSwitches    int    // cap on account failover switches per request (default 3)
	streamMode            string // lowercased cfg.Gateway.SoraStreamMode; "error" rejects stream=false requests
	soraTLSEnabled        bool   // inverse of cfg.Sora.Client.DisableTLSFingerprint
	soraMediaSigningKey   string // trimmed cfg.Gateway.SoraMediaSigningKey; empty when unset
	soraMediaRoot         string // local media storage root (default "/app/data/sora")
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler. A nil cfg leaves
// every tunable at its built-in default.
func NewSoraGatewayHandler(
	gatewayService *service.GatewayService,
	soraGatewayService *service.SoraGatewayService,
	concurrencyService *service.ConcurrencyService,
	billingCacheService *service.BillingCacheService,
	usageRecordWorkerPool *service.UsageRecordWorkerPool,
	cfg *config.Config,
) *SoraGatewayHandler {
	// Built-in defaults, overridden below when configuration is present.
	var (
		ping       = time.Duration(0)
		switches   = 3
		mode       = "force"
		tlsEnabled = true
		signingKey = ""
		mediaRoot  = "/app/data/sora"
	)
	if cfg != nil {
		ping = time.Duration(cfg.Concurrency.PingInterval) * time.Second
		if n := cfg.Gateway.MaxAccountSwitches; n > 0 {
			switches = n
		}
		if m := strings.TrimSpace(cfg.Gateway.SoraStreamMode); m != "" {
			mode = m
		}
		tlsEnabled = !cfg.Sora.Client.DisableTLSFingerprint
		signingKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
		if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
			mediaRoot = root
		}
	}
	return &SoraGatewayHandler{
		gatewayService:        gatewayService,
		soraGatewayService:    soraGatewayService,
		billingCacheService:   billingCacheService,
		usageRecordWorkerPool: usageRecordWorkerPool,
		concurrencyHelper:     NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, ping),
		maxAccountSwitches:    switches,
		streamMode:            strings.ToLower(mode),
		soraTLSEnabled:        tlsEnabled,
		soraMediaSigningKey:   signingKey,
		soraMediaRoot:         mediaRoot,
	}
}
// ChatCompletions handles the Sora /v1/chat/completions endpoint.
//
// Request flow: authenticate the API key and user, validate the JSON body,
// enforce the Sora platform, acquire user- and account-level concurrency
// slots (with bounded waiting), then forward upstream with account failover
// capped at maxAccountSwitches. On success, usage is recorded asynchronously
// through a bounded worker pool. Once SSE output has started, errors are
// reported via the streaming-aware error path.
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}

	reqLog := requestLogger(
		c,
		"handler.sora_gateway.chat_completions",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)

	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	// Record the raw body for ops error logging before validation; the
	// model/stream values are re-recorded below once they are known.
	setOpsRequestContext(c, "", false, body)

	// Validate that the request body is well-formed JSON.
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// Read-only field extraction with gjson for validation, avoiding a full
	// Unmarshal of the body.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()
	msgsResult := gjson.GetBytes(body, "messages")
	if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
		return
	}
	clientStream := gjson.GetBytes(body, "stream").Bool()
	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
	if !clientStream {
		// streamMode "error" rejects non-streaming requests; any other mode
		// forces stream=true in the upstream body.
		if h.streamMode == "error" {
			h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
			return
		}
		var err error
		body, err = sjson.SetBytes(body, "stream", true)
		if err != nil {
			h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
			return
		}
	}
	setOpsRequestContext(c, reqModel, clientStream, body)

	// Resolve the platform: a forced platform from middleware wins over the
	// API key group's platform.
	platform := ""
	if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
		platform = forced
	} else if apiKey.Group != nil {
		platform = apiKey.Group.Platform
	}
	if platform != service.PlatformSora {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
		return
	}

	streamStarted := false
	subscription, _ := middleware2.GetSubscriptionFromContext(c)

	// User-level wait queue: bound how many requests may be pending per user.
	maxWait := service.CalculateMaxWait(subject.Concurrency)
	canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
	waitCounted := false
	if err != nil {
		reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
	} else if !canWait {
		reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
		h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return
	}
	if err == nil && canWait {
		waitCounted = true
	}
	defer func() {
		// Safety net: only fires when the explicit decrement below was
		// skipped (e.g. an early return while still counted as waiting).
		if waitCounted {
			h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		}
	}()

	userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
	if err != nil {
		reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", streamStarted)
		return
	}
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		waitCounted = false
	}
	userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}

	sessionHash := generateOpenAISessionHash(c, body)

	// Account failover loop: retry on UpstreamFailoverError up to
	// maxAccountSwitches times, excluding accounts that already failed.
	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	failedAccountIDs := make(map[int64]struct{})
	lastFailoverStatus := 0
	var lastFailoverBody []byte
	var lastFailoverHeaders http.Header

	for {
		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
		if err != nil {
			reqLog.Warn("sora.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			if len(failedAccountIDs) == 0 {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
				return
			}
			// Candidates exhausted after failovers: surface the last upstream
			// failover response to the client.
			rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
			fields := []zap.Field{
				zap.Int("last_upstream_status", lastFailoverStatus),
			}
			if rayID != "" {
				fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
			}
			if mitigated != "" {
				fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
			}
			if contentType != "" {
				fields = append(fields, zap.String("last_upstream_content_type", contentType))
			}
			reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
			h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
			return
		}

		account := selection.Account
		setOpsSelectedAccount(c, account.ID, account.Platform)
		proxyBound := account.ProxyID != nil
		proxyID := int64(0)
		if account.ProxyID != nil {
			proxyID = *account.ProxyID
		}
		tlsFingerprintEnabled := h.soraTLSEnabled
		accountReleaseFunc := selection.ReleaseFunc
		if !selection.Acquired {
			// No free slot on the selected account: wait according to the
			// wait plan, bounded by its queue size and timeout.
			if selection.WaitPlan == nil {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
				return
			}
			accountWaitCounted := false
			canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
			if err != nil {
				reqLog.Warn("sora.account_wait_counter_increment_failed",
					zap.Int64("account_id", account.ID),
					zap.Int64("proxy_id", proxyID),
					zap.Bool("proxy_bound", proxyBound),
					zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
					zap.Error(err),
				)
			} else if !canWait {
				reqLog.Info("sora.account_wait_queue_full",
					zap.Int64("account_id", account.ID),
					zap.Int64("proxy_id", proxyID),
					zap.Bool("proxy_bound", proxyBound),
					zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
					zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
				)
				h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
				return
			}
			if err == nil && canWait {
				accountWaitCounted = true
			}
			defer func() {
				// NOTE: registered inside the loop, so it fires at function
				// return; accountWaitCounted is cleared right after the
				// explicit decrement below, preventing a double decrement.
				if accountWaitCounted {
					h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
				}
			}()
			accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
				c,
				account.ID,
				selection.WaitPlan.MaxConcurrency,
				selection.WaitPlan.Timeout,
				clientStream,
				&streamStarted,
			)
			if err != nil {
				reqLog.Warn("sora.account_slot_acquire_failed",
					zap.Int64("account_id", account.ID),
					zap.Int64("proxy_id", proxyID),
					zap.Bool("proxy_bound", proxyBound),
					zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
					zap.Error(err),
				)
				h.handleConcurrencyError(c, err, "account", streamStarted)
				return
			}
			if accountWaitCounted {
				h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
				accountWaitCounted = false
			}
		}
		accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)

		result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
		// Release the account slot as soon as the upstream call completes,
		// before any failover/usage handling.
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				failedAccountIDs[account.ID] = struct{}{}
				if switchCount >= maxAccountSwitches {
					lastFailoverStatus = failoverErr.StatusCode
					lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
					lastFailoverBody = failoverErr.ResponseBody
					rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
					fields := []zap.Field{
						zap.Int64("account_id", account.ID),
						zap.Int64("proxy_id", proxyID),
						zap.Bool("proxy_bound", proxyBound),
						zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
						zap.Int("upstream_status", failoverErr.StatusCode),
						zap.Int("switch_count", switchCount),
						zap.Int("max_switches", maxAccountSwitches),
					}
					if rayID != "" {
						fields = append(fields, zap.String("upstream_cf_ray", rayID))
					}
					if mitigated != "" {
						fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
					}
					if contentType != "" {
						fields = append(fields, zap.String("upstream_content_type", contentType))
					}
					reqLog.Warn("sora.upstream_failover_exhausted", fields...)
					h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
					return
				}
				lastFailoverStatus = failoverErr.StatusCode
				lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
				lastFailoverBody = failoverErr.ResponseBody
				switchCount++
				upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
				rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
				fields := []zap.Field{
					zap.Int64("account_id", account.ID),
					zap.Int64("proxy_id", proxyID),
					zap.Bool("proxy_bound", proxyBound),
					zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.String("upstream_error_code", upstreamErrCode),
					zap.String("upstream_error_message", upstreamErrMsg),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				}
				if rayID != "" {
					fields = append(fields, zap.String("upstream_cf_ray", rayID))
				}
				if mitigated != "" {
					fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
				}
				if contentType != "" {
					fields = append(fields, zap.String("upstream_content_type", contentType))
				}
				reqLog.Warn("sora.upstream_failover_switching", fields...)
				continue
			}
			reqLog.Error("sora.forward_failed",
				zap.Int64("account_id", account.ID),
				zap.Int64("proxy_id", proxyID),
				zap.Bool("proxy_bound", proxyBound),
				zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
				zap.Error(err),
			)
			return
		}

		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)
		// Usage records are submitted through a bounded worker pool to avoid
		// creating unbounded goroutines on the request hot path.
		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
				Result:       result,
				APIKey:       apiKey,
				User:         apiKey.User,
				Account:      account,
				Subscription: subscription,
				UserAgent:    userAgent,
				IPAddress:    clientIP,
			}); err != nil {
				logger.L().With(
					zap.String("component", "handler.sora_gateway.chat_completions"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("sora.record_usage_failed", zap.Error(err))
			}
		})

		reqLog.Debug("sora.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int64("proxy_id", proxyID),
			zap.Bool("proxy_bound", proxyBound),
			zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}
// generateOpenAISessionHash derives a stable, non-reversible session
// identifier for an OpenAI-style request. It checks the "session_id"
// header, then the "conversation_id" header, then the body's
// "prompt_cache_key" field, and returns the hex-encoded SHA-256 of the
// first non-empty candidate, or "" when none is present.
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
	if c == nil {
		return ""
	}
	var sessionID string
	for _, candidate := range []string{
		strings.TrimSpace(c.GetHeader("session_id")),
		strings.TrimSpace(c.GetHeader("conversation_id")),
	} {
		if candidate != "" {
			sessionID = candidate
			break
		}
	}
	if sessionID == "" && len(body) > 0 {
		sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
	}
	if sessionID == "" {
		return ""
	}
	sum := sha256.Sum256([]byte(sessionID))
	return hex.EncodeToString(sum[:])
}
// submitUsageRecordTask hands a usage-record task to the bounded worker
// pool when one is injected; otherwise it runs the task synchronously with
// a 10-second deadline so we never regress to unbounded goroutine creation.
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	pool := h.usageRecordWorkerPool
	if pool == nil {
		// Fallback path: no worker pool was injected — execute inline with a
		// bounded timeout instead of spawning an untracked goroutine.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		task(ctx)
		return
	}
	pool.Submit(task)
}
// handleConcurrencyError reports concurrency-slot exhaustion to the client
// as a 429 rate-limit error, routed through the stream-aware error path so
// an in-flight SSE response receives an error event rather than a JSON body.
// NOTE(review): err is currently unused; kept for call-site symmetry.
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
	msg := fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType)
	h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", msg, streamStarted)
}
// handleFailoverExhausted is called when all account-failover attempts have
// been used up: it maps the last upstream failure to a client-facing status,
// type, and message, then emits it via the stream-aware error path.
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
	mappedStatus, mappedType, mappedMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
	h.handleStreamingAwareError(c, mappedStatus, mappedType, mappedMsg, streamStarted)
}
// mapUpstreamError translates an upstream Sora failure response into the
// HTTP status, error type, and message returned to our own clients.
// Precedence: Cloudflare challenge pages first, then the cf_shield_429
// error code, then passthrough of safe upstream messages, and finally
// generic per-status fallback text.
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
	// Cloudflare interactive challenge: surface a proxy/network hint.
	if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
		msg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
		return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(msg, responseHeaders, responseBody)
	}

	code, message := extractUpstreamErrorCodeAndMessage(responseBody)

	// Cloudflare shield rate limiting reported through the error code.
	if strings.EqualFold(code, "cf_shield_429") {
		msg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
		return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(msg, responseHeaders, responseBody)
	}

	// Forward the upstream's own message verbatim when it is safe to expose.
	if shouldPassthroughSoraUpstreamMessage(statusCode, message) {
		switch statusCode {
		case 401, 403, 404, 500, 502, 503, 504:
			return http.StatusBadGateway, "upstream_error", message
		case 429:
			return http.StatusTooManyRequests, "rate_limit_error", message
		}
	}

	// Generic fallback text keyed on the upstream status code.
	switch statusCode {
	case 401:
		return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
	case 403:
		return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
	case 404:
		if strings.EqualFold(code, "unsupported_country_code") {
			return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
		}
		return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
	case 429:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529:
		return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
	case 500, 502, 503, 504:
		return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
	default:
		return http.StatusBadGateway, "upstream_error", "Upstream request failed"
	}
}
func cloneHTTPHeaders(headers http.Header) http.Header {
if headers == nil {
return nil
}
return headers.Clone()
}
// extractSoraFailoverHeaderInsights pulls Cloudflare diagnostics out of a
// failed upstream response for structured logging: the CF ray ID, the
// "cf-mitigated" marker, and the response content type.
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
	if headers != nil {
		mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
		// http.Header.Get canonicalizes its key, so a single lookup covers
		// both "content-type" and "Content-Type" spellings — the previous
		// second lookup was dead code.
		contentType = strings.TrimSpace(headers.Get("Content-Type"))
	}
	rayID = soraerror.ExtractCloudflareRayID(headers, body)
	return rayID, mitigated, contentType
}
// isSoraCloudflareChallengeResponse reports whether the upstream response
// looks like a Cloudflare challenge page; it delegates the detection to the
// shared soraerror helper.
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
	return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
// shouldPassthroughSoraUpstreamMessage decides whether an upstream error
// message is safe to forward to our client verbatim. Empty messages are
// rejected; for 403/429 responses, Cloudflare HTML challenge pages are
// filtered out so raw HTML is never echoed to API consumers.
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
	trimmed := strings.TrimSpace(message)
	if trimmed == "" {
		return false
	}
	if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
		return true
	}
	lower := strings.ToLower(trimmed)
	for _, marker := range []string{"<html", "<!doctype html", "window._cf_chl_opt"} {
		if strings.Contains(lower, marker) {
			return false
		}
	}
	return true
}
// formatSoraCloudflareChallengeMessage augments a base error message with
// Cloudflare challenge details extracted from the response; it delegates
// formatting to the shared soraerror helper.
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
	return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
// extractUpstreamErrorCodeAndMessage parses the upstream response body and
// returns its error code and message (empty strings when absent); it
// delegates to the shared soraerror helper.
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
	return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
// handleStreamingAwareError delivers an error to the client in the form the
// connection can still accept: a normal JSON error response before any
// streaming has begun, or an SSE "error" event once the stream has started
// (at that point the status line is already committed).
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
	if !streamStarted {
		h.errorResponse(c, status, errType, message)
		return
	}
	// Mid-stream: only emit the event when the writer supports flushing;
	// otherwise there is nothing useful we can push to the client.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return
	}
	payload, err := json.Marshal(map[string]any{
		"error": map[string]string{
			"type":    errType,
			"message": message,
		},
	})
	if err != nil {
		_ = c.Error(err)
		return
	}
	if _, err := fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", payload); err != nil {
		_ = c.Error(err)
	}
	flusher.Flush()
}
// errorResponse writes the standard JSON error envelope
// {"error": {"type": ..., "message": ...}} with the given HTTP status.
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
	body := gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	}
	c.JSON(status, body)
}
// MediaProxy serves local Sora media files without requiring a URL
// signature.
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
	h.proxySoraMedia(c, false)
}
// MediaProxySigned serves local Sora media files, rejecting requests whose
// "sig"/"expires" query parameters fail signature verification.
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
	h.proxySoraMedia(c, true)
}
// proxySoraMedia serves a locally stored Sora media file (image or video).
// When requireSignature is true the request must carry a valid, unexpired
// HMAC signature supplied via the "sig" and "expires" query parameters.
// Fix: directory paths are now rejected with 404 — previously a directory
// under /image/ or /video/ passed the os.Stat check and reached c.File,
// which hands directories to http.ServeFile's directory handling
// (redirect/index lookup) instead of returning Not Found.
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
	rawPath := c.Param("filepath")
	if rawPath == "" {
		c.Status(http.StatusNotFound)
		return
	}
	// path.Clean collapses ".." segments; combined with the prefix whitelist
	// below, serving is confined to the image/ and video/ subtrees.
	cleaned := path.Clean(rawPath)
	if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
		c.Status(http.StatusNotFound)
		return
	}
	query := c.Request.URL.Query()
	if requireSignature {
		if h.soraMediaSigningKey == "" {
			c.JSON(http.StatusServiceUnavailable, gin.H{
				"error": gin.H{
					"type":    "api_error",
					"message": "Sora 媒体签名未配置",
				},
			})
			return
		}
		expiresStr := strings.TrimSpace(query.Get("expires"))
		signature := strings.TrimSpace(query.Get("sig"))
		expires, err := strconv.ParseInt(expiresStr, 10, 64)
		if err != nil || expires <= time.Now().Unix() {
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": gin.H{
					"type":    "authentication_error",
					"message": "Sora 媒体签名已过期",
				},
			})
			return
		}
		// The signature covers the URL minus the sig/expires parameters.
		query.Del("sig")
		query.Del("expires")
		signingQuery := query.Encode()
		if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": gin.H{
					"type":    "authentication_error",
					"message": "Sora 媒体签名无效",
				},
			})
			return
		}
	}
	if strings.TrimSpace(h.soraMediaRoot) == "" {
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"error": gin.H{
				"type":    "api_error",
				"message": "Sora 媒体目录未配置",
			},
		})
		return
	}
	relative := strings.TrimPrefix(cleaned, "/")
	localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
	info, err := os.Stat(localPath)
	if err != nil {
		if os.IsNotExist(err) {
			c.Status(http.StatusNotFound)
			return
		}
		c.Status(http.StatusInternalServerError)
		return
	}
	// Only serve regular files; see the directory-handling note above.
	if info.IsDir() {
		c.Status(http.StatusNotFound)
		return
	}
	c.File(localPath)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment