Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
package handler
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type helperConcurrencyCacheStub struct {
mu sync.Mutex
accountSeq []bool
userSeq []bool
accountAcquireCalls int
userAcquireCalls int
accountReleaseCalls int
userReleaseCalls int
}
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.accountAcquireCalls++
if len(s.accountSeq) == 0 {
return false, nil
}
v := s.accountSeq[0]
s.accountSeq = s.accountSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.accountReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
out := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
out[accountID] = 0
}
return out, nil
}
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.userAcquireCalls++
if len(s.userSeq) == 0 {
return false, nil
}
v := s.userSeq[0]
s.userSeq = s.userSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.userReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
out := make(map[int64]*service.UserLoadInfo, len(users))
for _, user := range users {
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(method, path, nil)
return c, rec
}
func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
}`)
}
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "curl/8.6.0")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
SetClaudeCodeClientContext(c, nil, nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
// 缺少严格校验所需 header + body 字段
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
parsedReq := &service.ParsedRequest{
Model: "claude-3-5-sonnet-20241022",
System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
}
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
"model": "claude-3-5-sonnet-20241022",
"system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
})
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, true},
userSeq: []bool{false, true},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
require.False(t, streamStarted)
release()
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
})
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
release()
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
})
}
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, false, false},
}
concurrency := service.NewConcurrencyService(cache)
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
})
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.True(t, streamStarted)
require.Contains(t, rec.Body.String(), ":\n\n")
})
}
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
errCache := &helperConcurrencyCacheStubWithError{
err: errors.New("redis unavailable"),
}
concurrency := service.NewConcurrencyService(errCache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
require.Error(t, err)
require.Contains(t, err.Error(), "redis unavailable")
}
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
}
type helperConcurrencyCacheStubWithError struct {
helperConcurrencyCacheStub
err error
}
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, s.err
}
......@@ -7,24 +7,23 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
......@@ -143,6 +142,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gemini_v1beta.models",
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
......@@ -159,8 +165,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
......@@ -187,8 +194,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
......@@ -208,6 +216,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
......@@ -223,6 +232,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
......@@ -252,6 +262,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
// === Gemini 内容摘要会话 Fallback 逻辑 ===
......@@ -296,8 +314,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
reqLog.Info("gemini.digest_fallback_matched",
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
zap.Int64("account_id", foundAccountID),
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
)
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
......@@ -321,55 +342,54 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
reqLog.Info("gemini.sticky_session_account_switched",
zap.Int64("from_account_id", sessionBoundAccountID),
zap.Int64("to_account_id", account.ID),
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
reqLog.Info("gemini.sticky_session_binding_missing",
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
......@@ -388,9 +408,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gemini.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
......@@ -412,6 +435,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted,
)
if err != nil {
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
......@@ -420,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
......@@ -429,8 +453,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
......@@ -443,27 +467,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch failoverAction {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
case FailoverCanceled:
return
}
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
......@@ -482,31 +498,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", fs.SwitchCount),
)
return
}
}
......
......@@ -11,6 +11,7 @@ type AdminHandlers struct {
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
......@@ -25,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
}
// Handlers contains all HTTP handlers
......@@ -39,6 +41,8 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
SoraClient *SoraClientHandler
Setting *SettingHandler
Totp *TotpHandler
}
......
package handler
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func executeUserIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
return
}
actorScope := "user:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
}
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
package handler
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userStoreUnavailableRepoStub struct{}
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
type userMemoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
return &userMemoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func withUserSubject(userID int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
}
}
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(nil)
var executed int
router := gin.New()
router.Use(withUserSubject(1))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, executed)
}
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.Use(withUserSubject(2))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "k1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed)
}
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newUserMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.Use(withUserSubject(3))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(80 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-user-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() { defer wg.Done(); status1, _ = call() }()
go func() { defer wg.Done(); status2, _ = call() }()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load())
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
base := logger.L()
if c != nil && c.Request != nil {
base = logger.FromContext(c.Request.Context())
}
if component != "" {
fields = append([]zap.Field{zap.String("component", component)}, fields...)
}
return base.With(fields...)
}
......@@ -5,19 +5,23 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
......@@ -25,6 +29,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
......@@ -36,6 +41,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *OpenAIGatewayHandler {
......@@ -51,6 +57,7 @@ func NewOpenAIGatewayHandler(
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
......@@ -60,6 +67,13 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
......@@ -72,9 +86,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
// Read request body
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
......@@ -91,64 +115,51 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// 验证 model 必填
if reqModel == "" {
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
if previousResponseID != "" {
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
reqLog = reqLog.With(
zap.Bool("has_previous_response_id", true),
zap.String("previous_response_id_kind", previousResponseIDKind),
zap.Int("previous_response_id_len", len(previousResponseID)),
)
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "previous_response_id_looks_like_message_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
}
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
......@@ -157,54 +168,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
// User slot acquired: no longer waiting.
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
......@@ -213,12 +198,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
......@@ -228,67 +224,53 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
return
}
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
if previousResponseID != "" && selection != nil && selection.Account != nil {
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
}
reqLog.Debug("openai.account_schedule_decision",
zap.String("layer", scheduleDecision.Layer),
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
zap.Int("top_k", scheduleDecision.TopK),
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
zap.Float64("load_skew", scheduleDecision.LoadSkew),
)
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
setOpsSelectedAccount(c, account.ID)
// 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
......@@ -296,37 +278,629 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
// Error response already handled in Forward, just log
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return
}
reqLog.Error("openai.forward_failed", fields...)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
return true
}
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
validation := service.ValidateFunctionCallOutputContext(reqBody)
if !validation.HasFunctionCallOutput {
return true
}
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
return true
}
if validation.HasFunctionCallOutputMissingCallID {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return false
}
if validation.HasItemReferenceForAllCallIDs {
return true
}
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return false
}
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
c *gin.Context,
userID int64,
userConcurrency int,
reqStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), bool) {
ctx := c.Request.Context()
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", *streamStarted)
return nil, false
}
if userAcquired {
return wrapReleaseOnDone(ctx, userReleaseFunc), true
}
maxWait := service.CalculateMaxWait(userConcurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
if waitErr != nil {
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return nil, false
}
waitCounted := waitErr == nil && canWait
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
}
}()
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
h.handleConcurrencyError(c, err, "user", *streamStarted)
return nil, false
}
// 槽位获取成功后,立刻退出等待计数。
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
waitCounted = false
}
return wrapReleaseOnDone(ctx, userReleaseFunc), true
}
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
c *gin.Context,
groupID *int64,
sessionHash string,
selection *service.AccountSelectionResult,
reqStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), bool) {
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
ctx := c.Request.Context()
account := selection.Account
if selection.Acquired {
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
}
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", *streamStarted)
return nil, false
}
if fastAcquired {
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
}
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
if waitErr != nil {
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
} else if !canWait {
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
return nil, false
}
accountWaitCounted := waitErr == nil && canWait
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
accountWaitCounted = false
}
}
defer releaseWait()
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
streamStarted,
)
if err != nil {
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", *streamStarted)
return nil, false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
}
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
// GET /openai/v1/responses (Upgrade: websocket)
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
if !isOpenAIWSUpgradeRequest(c.Request) {
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
return
}
setOpenAIClientTransportWS(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses_ws",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.Bool("openai_ws_mode", true),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
reqLog.Info("openai.websocket_ingress_started")
clientIP := ip.GetClientIP(c)
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
reqLog.Warn("openai.websocket_accept_failed",
zap.Error(err),
zap.String("client_ip", clientIP),
zap.String("request_user_agent", userAgent),
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
)
return
}
defer func() {
_ = wsConn.CloseNow()
}()
wsConn.SetReadLimit(16 * 1024 * 1024)
ctx := c.Request.Context()
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
msgType, firstMessage, err := wsConn.Read(readCtx)
cancel()
if err != nil {
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_read_first_message_failed",
zap.Error(err),
zap.String("client_ip", clientIP),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
zap.Duration("read_timeout", 30*time.Second),
)
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
return
}
if !gjson.ValidBytes(firstMessage) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
return
}
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
if reqModel == "" {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
return
}
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
return
}
reqLog = reqLog.With(
zap.Bool("ws_ingress", true),
zap.String("model", reqModel),
zap.Bool("has_previous_response_id", previousResponseID != ""),
zap.String("previous_response_id_kind", previousResponseIDKind),
)
setOpsRequestContext(c, reqModel, true, firstMessage)
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
if currentAccountRelease != nil {
currentAccountRelease()
currentAccountRelease = nil
}
if currentUserRelease != nil {
currentUserRelease()
currentUserRelease = nil
}
}
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
defer releaseTurnSlots()
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
return
}
if !userAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
return
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
return
}
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
c,
firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
)
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
return
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
}(result, account, userAgent, clientIP)
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil || result == nil {
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
recovered := recover()
if recovered == nil {
return
}
started := false
if streamStarted != nil {
started = *streamStarted
}
wroteFallback := h.ensureForwardErrorResponse(c, started)
requestLogger(c, "handler.openai_gateway.responses").Error(
"openai.responses_panic_recovered",
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Any("panic", recovered),
zap.ByteString("stack", debug.Stack()),
)
}
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
missing := h.missingResponsesDependencies()
if len(missing) == 0 {
return true
}
if reqLog == nil {
reqLog = requestLogger(c, "handler.openai_gateway.responses")
}
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
if c != nil && c.Writer != nil && !c.Writer.Written() {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Service temporarily unavailable",
},
})
}
return false
}
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
missing := make([]string, 0, 5)
if h == nil {
return append(missing, "handler")
}
if h.gatewayService == nil {
missing = append(missing, "gatewayService")
}
if h.billingCacheService == nil {
missing = append(missing, "billingCacheService")
}
if h.apiKeyService == nil {
missing = append(missing, "apiKeyService")
}
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
missing = append(missing, "concurrencyHelper")
}
return missing
}
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
}
v, ok := c.Get(key)
if !ok {
return 0, false
}
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
case int32:
return int64(t), true
case float64:
return int64(t), true
default:
return 0, false
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Any("panic", recovered),
).Error("openai.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
......@@ -397,8 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
......@@ -411,6 +985,25 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
h.errorResponse(c, status, errType, message)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
if wroteFallback {
return false
}
if c == nil || c.Writer == nil {
return false
}
return c.Writer.Written()
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
......@@ -420,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
},
})
}
func setOpenAIClientTransportHTTP(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}
func setOpenAIClientTransportWS(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
gid := int64(0)
if groupID != nil {
gid = *groupID
}
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
return false
}
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
}
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
if conn == nil {
return
}
reason = strings.TrimSpace(reason)
if len(reason) > 120 {
reason = reason[:120]
}
_ = conn.Close(status, reason)
_ = conn.CloseNow()
}
func summarizeWSCloseErrorForLog(err error) (string, string) {
if err == nil {
return "-", "-"
}
statusCode := coderws.CloseStatus(err)
if statusCode == -1 {
return "-", "-"
}
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
closeReason := "-"
var closeErr coderws.CloseError
if errors.As(err, &closeErr) {
reason := strings.TrimSpace(closeErr.Reason)
if reason != "" {
closeReason = reason
}
}
return closeStatus, closeReason
}
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号的消息",
errType: "server_error",
message: `upstream returned "invalid" response`,
},
{
name: "包含反斜杠的消息",
errType: "server_error",
message: `path C:\Users\test\file.txt not found`,
},
{
name: "包含双引号和反斜杠的消息",
errType: "upstream_error",
message: `error parsing "key\value": unexpected token`,
},
{
name: "包含换行符的消息",
errType: "server_error",
message: "line1\nline2\ttab",
},
{
name: "普通消息",
errType: "upstream_error",
message: "Upstream service temporarily unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
// 验证 SSE 格式:event: error\ndata: {JSON}\n\n
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
// 提取 data 部分
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "应有 event 行和 data 行")
dataLine := lines[1]
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
jsonStr := strings.TrimPrefix(dataLine, "data: ")
// 验证 JSON 合法性
var parsed map[string]any
err := json.Unmarshal([]byte(jsonStr), &parsed)
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
// 验证结构
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "应包含 error 对象")
assert.Equal(t, tt.errType, errorObj["type"])
assert.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
// 非流式应返回 JSON 响应
assert.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "test error", errorObj["message"])
}
func TestReadRequestBodyWithPrealloc(t *testing.T) {
payload := `{"model":"gpt-5","input":"hello"}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
req.ContentLength = int64(len(payload))
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
require.NoError(t, err)
require.Equal(t, payload, string(body))
}
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
req.Body = http.MaxBytesReader(rec, req.Body, 4)
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
require.Error(t, err)
var maxErr *http.MaxBytesError
require.ErrorAs(t, err, &maxErr)
}
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
})
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
})
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
})
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusForbidden, "already written")
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
})
}
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
panic("test panic")
}()
})
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
}()
})
require.False(t, c.Writer.Written())
assert.Equal(t, "", w.Body.String())
}
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
panic("test panic")
}()
})
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
t.Run("nil_handler", func(t *testing.T) {
var h *OpenAIGatewayHandler
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
})
t.Run("all_dependencies_missing", func(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.Equal(t,
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
h.missingResponsesDependencies(),
)
})
t.Run("all_dependencies_present", func(t *testing.T) {
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{
concurrencyService: &service.ConcurrencyService{},
},
}
require.Empty(t, h.missingResponsesDependencies())
})
}
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
ok := h.ensureResponsesDependencies(c, nil)
require.False(t, ok)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, exists := parsed["error"].(map[string]any)
require.True(t, exists)
assert.Equal(t, "api_error", errorObj["type"])
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
})
t.Run("already_written_response_not_overridden", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
ok := h.ensureResponsesDependencies(c, nil)
require.False(t, ok)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
})
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{
concurrencyService: &service.ConcurrencyService{},
},
}
ok := h.ensureResponsesDependencies(c, nil)
require.True(t, ok)
require.False(t, c.Writer.Written())
assert.Equal(t, "", w.Body.String())
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.Responses(c)
})
require.Equal(t, http.StatusServiceUnavailable, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "api_error", errorObj["type"])
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
}
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
c.Request.Header.Set("Content-Type", "application/json")
h := &OpenAIGatewayHandler{}
h.Responses(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: 1},
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
h.Responses(c)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
c.Request.Header.Set("Upgrade", "websocket")
c.Request.Header.Set("Connection", "Upgrade")
h := &OpenAIGatewayHandler{}
h.ResponsesWebSocket(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
h := &OpenAIGatewayHandler{}
h.ResponsesWebSocket(c)
require.Equal(t, http.StatusUpgradeRequired, w.Code)
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, _, err = clientConn.Read(readCtx)
cancelRead()
require.Error(t, err)
var closeErr coderws.CloseError
require.ErrorAs(t, err, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
}
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
gin.SetMode(gin.TestMode)
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return false, errors.New("user slot unavailable")
},
}
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, _, err = clientConn.Read(readCtx)
cancelRead()
require.Error(t, err)
var closeErr coderws.CloseError
require.ErrorAs(t, err, &closeErr)
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
setOpenAIClientTransportHTTP(c)
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
}
func TestSetOpenAIClientTransportWS(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
setOpenAIClientTransportWS(c)
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantStream bool
}{
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
{"model 缺失", `{"stream":true}`, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := []byte(tt.body)
modelResult := gjson.GetBytes(body, "model")
model := ""
if modelResult.Type == gjson.String {
model = modelResult.String()
}
stream := gjson.GetBytes(body, "stream").Bool()
require.Equal(t, tt.wantModel, model)
require.Equal(t, tt.wantStream, stream)
})
}
}
// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
// model 为数字 → 类型不是 gjson.String,应被拒绝
body := []byte(`{"model":123}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, modelResult.Exists())
require.NotEqual(t, gjson.String, modelResult.Type)
// model 为 null → 类型不是 gjson.String,应被拒绝
body2 := []byte(`{"model":null}`)
modelResult2 := gjson.GetBytes(body2, "model")
require.True(t, modelResult2.Exists())
require.NotEqual(t, gjson.String, modelResult2.Type)
// stream 为 string → 类型既不是 True 也不是 False,应被拒绝
body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
streamResult := gjson.GetBytes(body3, "stream")
require.True(t, streamResult.Exists())
require.NotEqual(t, gjson.True, streamResult.Type)
require.NotEqual(t, gjson.False, streamResult.Type)
// stream 为 int → 同上
body4 := []byte(`{"model":"gpt-4","stream":1}`)
streamResult2 := gjson.GetBytes(body4, "stream")
require.True(t, streamResult2.Exists())
require.NotEqual(t, gjson.True, streamResult2.Type)
require.NotEqual(t, gjson.False, streamResult2.Type)
}
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
// 测试 1:无 instructions → 注入
body := []byte(`{"model":"gpt-4"}`)
existing := gjson.GetBytes(body, "instructions").String()
require.Empty(t, existing)
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
require.NoError(t, err)
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
// 测试 2:已有 instructions → 不覆盖
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
existing2 := gjson.GetBytes(body2, "instructions").String()
require.Equal(t, "existing", existing2)
// 测试 3:空白 instructions → 注入
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
require.Empty(t, existing3)
// 测试 4:sjson.SetBytes 返回错误时不应 panic
// 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理
validBody := []byte(`{"model":"gpt-4"}`)
result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
require.NoError(t, setErr)
require.True(t, gjson.ValidBytes(result))
}
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
t.Helper()
if cache == nil {
cache = &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
}
return &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
}
}
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
t.Helper()
groupID := int64(2)
apiKey := &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: subject.UserID},
}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), subject)
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
return httptest.NewServer(router)
}
......@@ -41,9 +41,8 @@ const (
)
type opsErrorLogJob struct {
ops *service.OpsService
entry *service.OpsInsertErrorLogInput
requestBody []byte
ops *service.OpsService
entry *service.OpsInsertErrorLogInput
}
var (
......@@ -58,6 +57,7 @@ var (
opsErrorLogEnqueued atomic.Int64
opsErrorLogDropped atomic.Int64
opsErrorLogProcessed atomic.Int64
opsErrorLogSanitized atomic.Int64
opsErrorLogLastDropLogAt atomic.Int64
......@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
}
}()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
_ = job.ops.RecordError(ctx, job.entry, nil)
cancel()
opsErrorLogProcessed.Add(1)
}()
......@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
}
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
if ops == nil || entry == nil {
return
}
......@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
}
select {
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
opsErrorLogQueueLen.Add(1)
opsErrorLogEnqueued.Add(1)
default:
......@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
return opsErrorLogProcessed.Load()
}
func OpsErrorLogSanitizedTotal() int64 {
return opsErrorLogSanitized.Load()
}
func maybeLogOpsErrorLogDrop() {
now := time.Now().Unix()
......@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
queueCap := OpsErrorLogQueueCapacity()
log.Printf(
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
queued,
queueCap,
opsErrorLogEnqueued.Load(),
opsErrorLogDropped.Load(),
opsErrorLogProcessed.Load(),
opsErrorLogSanitized.Load(),
)
}
......@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
if c == nil {
return
}
model = strings.TrimSpace(model)
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody)
}
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
}
}
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
v, ok := c.Get(opsRequestBodyKey)
if !ok {
return
}
raw, ok := v.([]byte)
if !ok || len(raw) == 0 {
return
}
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
opsErrorLogSanitized.Add(1)
}
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 {
return
}
c.Set(opsAccountIDKey, accountID)
if c.Request != nil {
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
if len(platform) > 0 {
p := strings.TrimSpace(platform[0])
if p != "" {
ctx = context.WithValue(ctx, ctxkey.Platform, p)
}
}
c.Request = c.Request.WithContext(ctx)
}
}
type opsCaptureWriter struct {
......@@ -275,6 +311,35 @@ type opsCaptureWriter struct {
buf bytes.Buffer
}
const opsCaptureWriterLimit = 64 * 1024
var opsCaptureWriterPool = sync.Pool{
New: func() any {
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
},
}
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
if !ok || w == nil {
w = &opsCaptureWriter{}
}
w.ResponseWriter = rw
w.limit = opsCaptureWriterLimit
w.buf.Reset()
return w
}
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
if w == nil {
return
}
w.ResponseWriter = nil
w.limit = opsCaptureWriterLimit
w.buf.Reset()
opsCaptureWriterPool.Put(w)
}
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
remaining := w.limit - w.buf.Len()
......@@ -306,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
// - Streaming errors after the response has started (SSE) may still need explicit logging.
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) {
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
originalWriter := c.Writer
w := acquireOpsCaptureWriter(originalWriter)
defer func() {
// Restore the original writer before returning so outer middlewares
// don't observe a pooled wrapper that has been released.
if c.Writer == w {
c.Writer = originalWriter
}
releaseOpsCaptureWriter(w)
}()
c.Writer = w
c.Next()
......@@ -507,6 +581,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
if apiKey != nil {
entry.APIKeyID = &apiKey.ID
......@@ -528,14 +603,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
......@@ -544,7 +614,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
}
}
enqueueOpsErrorLog(ops, entry, requestBody)
enqueueOpsErrorLog(ops, entry)
return
}
......@@ -592,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
requestID = c.Writer.Header().Get("x-request-id")
}
phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
......@@ -615,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,
ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
Severity: classifyOpsSeverity(parsed.ErrorType, status),
ErrorType: normalizedType,
Severity: classifyOpsSeverity(normalizedType, status),
StatusCode: status,
IsBusinessLimited: isBusinessLimited,
IsCountTokens: isCountTokensRequest(c),
......@@ -628,10 +700,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
IsRetryable: classifyOpsIsRetryable(normalizedType, status),
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
// Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data.
......@@ -707,17 +780,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
enqueueOpsErrorLog(ops, entry, requestBody)
enqueueOpsErrorLog(ops, entry)
}
}
......@@ -760,6 +828,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
return &s
}
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
}
func getContextLatencyMs(c *gin.Context, key string) *int64 {
if c == nil || strings.TrimSpace(key) == "" {
return nil
}
v, ok := c.Get(key)
if !ok {
return nil
}
var ms int64
switch t := v.(type) {
case int:
ms = int64(t)
case int32:
ms = int64(t)
case int64:
ms = t
case float64:
ms = int64(t)
default:
return nil
}
if ms < 0 {
return nil
}
return &ms
}
type parsedOpsError struct {
ErrorType string
Message string
......@@ -835,8 +941,29 @@ func guessPlatformFromPath(path string) string {
}
}
// isKnownOpsErrorType returns true if t is a recognized error type used by the
// ops classification pipeline. Upstream proxies sometimes return garbage values
// (e.g. the Go-serialized literal "<nil>") which would pollute phase/severity
// classification if accepted blindly.
func isKnownOpsErrorType(t string) bool {
switch t {
case "invalid_request_error",
"authentication_error",
"rate_limit_error",
"billing_error",
"subscription_error",
"upstream_error",
"overloaded_error",
"api_error",
"not_found_error",
"forbidden_error":
return true
}
return false
}
func normalizeOpsErrorType(errType string, code string) string {
if errType != "" {
if errType != "" && isKnownOpsErrorType(errType) {
return errType
}
switch strings.TrimSpace(code) {
......
package handler
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func resetOpsErrorLoggerStateForTest(t *testing.T) {
t.Helper()
opsErrorLogMu.Lock()
ch := opsErrorLogQueue
opsErrorLogQueue = nil
opsErrorLogStopping = true
opsErrorLogMu.Unlock()
if ch != nil {
close(ch)
}
opsErrorLogWorkersWg.Wait()
opsErrorLogOnce = sync.Once{}
opsErrorLogStopOnce = sync.Once{}
opsErrorLogWorkersWg = sync.WaitGroup{}
opsErrorLogMu = sync.RWMutex{}
opsErrorLogStopping = false
opsErrorLogQueueLen.Store(0)
opsErrorLogEnqueued.Store(0)
opsErrorLogDropped.Store(0)
opsErrorLogProcessed.Store(0)
opsErrorLogSanitized.Store(0)
opsErrorLogLastDropLogAt.Store(0)
opsErrorLogShutdownCh = make(chan struct{})
opsErrorLogShutdownOnce = sync.Once{}
opsErrorLogDrained.Store(false)
}
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
setOpsRequestContext(c, "claude-3", false, raw)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(c, entry)
require.NotNil(t, entry.RequestBodyBytes)
require.Equal(t, len(raw), *entry.RequestBodyBytes)
require.NotNil(t, entry.RequestBodyJSON)
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
raw := []byte("not-json")
setOpsRequestContext(c, "claude-3", false, raw)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.NotNil(t, entry.RequestBodyBytes)
require.Equal(t, len(raw), *entry.RequestBodyBytes)
require.False(t, entry.RequestBodyTruncated)
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
opsErrorLogOnce.Do(func() {})
opsErrorLogMu.Lock()
opsErrorLogQueue = make(chan opsErrorLogJob, 1)
opsErrorLogMu.Unlock()
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
enqueueOpsErrorLog(ops, entry)
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
require.Equal(t, int64(1), OpsErrorLogQueueLength())
}
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(nil, entry)
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 无请求体 key
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
require.False(t, entry.RequestBodyTruncated)
// 错误类型
c.Set(opsRequestBodyKey, "not-bytes")
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
// 空 bytes
c.Set(opsRequestBodyKey, []byte{})
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
}
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
// nil 入参分支
enqueueOpsErrorLog(nil, entry)
enqueueOpsErrorLog(ops, nil)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
// shutdown 分支
close(opsErrorLogShutdownCh)
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
// stopping 分支
resetOpsErrorLoggerStateForTest(t)
opsErrorLogMu.Lock()
opsErrorLogStopping = true
opsErrorLogMu.Unlock()
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
// queue nil 分支(防止启动 worker 干扰)
resetOpsErrorLoggerStateForTest(t)
opsErrorLogOnce.Do(func() {})
opsErrorLogMu.Lock()
opsErrorLogQueue = nil
opsErrorLogMu.Unlock()
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
}
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
writer := acquireOpsCaptureWriter(c.Writer)
require.NotNil(t, writer)
_, err := writer.buf.WriteString("temp-error-body")
require.NoError(t, err)
releaseOpsCaptureWriter(writer)
reused := acquireOpsCaptureWriter(c.Writer)
defer releaseOpsCaptureWriter(reused)
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
}
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(middleware2.Recovery())
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger())
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
c.Status(http.StatusNoContent)
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
require.NotPanics(t, func() {
r.ServeHTTP(rec, req)
})
require.Equal(t, http.StatusNoContent, rec.Code)
}
func TestIsKnownOpsErrorType(t *testing.T) {
known := []string{
"invalid_request_error",
"authentication_error",
"rate_limit_error",
"billing_error",
"subscription_error",
"upstream_error",
"overloaded_error",
"api_error",
"not_found_error",
"forbidden_error",
}
for _, k := range known {
require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
}
unknown := []string{"<nil>", "null", "", "random_error", "some_new_type", "<nil>\u003e"}
for _, u := range unknown {
require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
}
}
func TestNormalizeOpsErrorType(t *testing.T) {
tests := []struct {
name string
errType string
code string
want string
}{
// Known types pass through.
{"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
{"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
{"known upstream_error", "upstream_error", "", "upstream_error"},
// Unknown/garbage types are rejected and fall through to code-based or default.
{"nil literal from upstream", "<nil>", "", "api_error"},
{"null string", "null", "", "api_error"},
{"random string", "something_weird", "", "api_error"},
// Unknown type but known code still maps correctly.
{"nil with INSUFFICIENT_BALANCE code", "<nil>", "INSUFFICIENT_BALANCE", "billing_error"},
{"nil with USAGE_LIMIT_EXCEEDED code", "<nil>", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
// Empty type falls through to code-based mapping.
{"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
{"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
{"empty type no code", "", "", "api_error"},
// Known type overrides conflicting code-based mapping.
{"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeOpsErrorType(tt.errType, tt.code)
require.Equal(t, tt.want, got)
})
}
}
......@@ -32,25 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
})
}
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// 上游模型缓存 TTL
modelCacheTTL = 1 * time.Hour // 上游获取成功
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type SoraClientHandler struct {
genService *service.SoraGenerationService
quotaService *service.SoraQuotaService
s3Storage *service.SoraS3Storage
soraGatewayService *service.SoraGatewayService
gatewayService *service.GatewayService
mediaStorage *service.SoraMediaStorage
apiKeyService *service.APIKeyService
// 上游模型缓存
modelCacheMu sync.RWMutex
cachedFamilies []service.SoraModelFamily
modelCacheTime time.Time
modelCacheUpstream bool // 是否来自上游(决定 TTL)
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func NewSoraClientHandler(
genService *service.SoraGenerationService,
quotaService *service.SoraQuotaService,
s3Storage *service.SoraS3Storage,
soraGatewayService *service.SoraGatewayService,
gatewayService *service.GatewayService,
mediaStorage *service.SoraMediaStorage,
apiKeyService *service.APIKeyService,
) *SoraClientHandler {
return &SoraClientHandler{
genService: genService,
quotaService: quotaService,
s3Storage: s3Storage,
soraGatewayService: soraGatewayService,
gatewayService: gatewayService,
mediaStorage: mediaStorage,
apiKeyService: apiKeyService,
}
}
// GenerateRequest 生成请求。
type GenerateRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MediaType string `json:"media_type"` // video / image,默认 video
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func (h *SoraClientHandler) Generate(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
var req GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
return
}
if req.MediaType == "" {
req.MediaType = "video"
}
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
// 并发数检查(最多 3 个)
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if activeCount >= 3 {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if h.quotaService != nil {
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusForbidden, err.Error())
return
}
}
// 获取 API Key ID 和 Group ID
var apiKeyID *int64
var groupID *int64
if req.APIKeyID != nil && h.apiKeyService != nil {
// 前端传递了 api_key_id,需要校验
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
if err != nil {
response.Error(c, http.StatusBadRequest, "API Key 不存在")
return
}
if apiKey.UserID != userID {
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
return
}
if apiKey.Status != service.StatusAPIKeyActive {
response.Error(c, http.StatusForbidden, "API Key 不可用")
return
}
apiKeyID = &apiKey.ID
groupID = apiKey.GroupID
} else if id, ok := c.Get("api_key_id"); ok {
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if v, ok := id.(int64); ok {
apiKeyID = &v
}
}
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
if err != nil {
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
response.ErrorFrom(c, err)
return
}
// 启动后台异步生成 goroutine
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
response.Success(c, gin.H{
"generation_id": gen.ID,
"status": gen.Status,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 标记为生成中
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
genID,
userID,
groupIDForLog(groupID),
model,
mediaType,
videoCount,
strings.TrimSpace(imageInput) != "",
len(strings.TrimSpace(prompt)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if groupID == nil {
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
}
if h.gatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
return
}
// 选择 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
genID,
userID,
groupIDForLog(groupID),
model,
err,
)
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
genID,
userID,
groupIDForLog(groupID),
model,
account.ID,
account.Name,
account.Platform,
account.Type,
)
// 构建 chat completions 请求体(非流式)
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
if h.soraGatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
return
}
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
recorder := httptest.NewRecorder()
mockGinCtx, _ := gin.CreateTestContext(recorder)
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
// 调用 Forward(非流式)
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
err,
)
// 检查是否已取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
return
}
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
return
}
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
if mediaURL == "" {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
)
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
return
}
// 检查任务是否已被取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
return
}
// 三层降级存储:S3 → 本地 → 上游临时 URL
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
usageAdded := false
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
return
}
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
return
}
usageAdded = true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen, _ = h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
// 标记完成
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
}
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
func (h *SoraClientHandler) storeMediaWithDegradation(
ctx context.Context, userID int64, mediaType string,
mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
urls := mediaURLs
if len(urls) == 0 {
urls = []string{mediaURL}
}
// 第一层:尝试 S3
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
keys := make([]string, 0, len(urls))
var totalSize int64
allOK := true
for _, u := range urls {
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
allOK = false
// 清理已上传的文件
if len(keys) > 0 {
_ = h.s3Storage.DeleteObjects(ctx, keys)
}
break
}
keys = append(keys, key)
totalSize += size
}
if allOK && len(keys) > 0 {
accessURLs := make([]string, 0, len(keys))
for _, key := range keys {
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
_ = h.s3Storage.DeleteObjects(ctx, keys)
allOK = false
break
}
accessURLs = append(accessURLs, accessURL)
}
if allOK && len(accessURLs) > 0 {
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
}
}
}
// 第二层:尝试本地存储
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
if err == nil && len(storedPaths) > 0 {
firstPath := storedPaths[0]
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
if sizeErr != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
}
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
}
// 第三层:保留上游临时 URL
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
body := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"stream": false,
}
if imageInput != "" {
body["image_input"] = imageInput
}
if videoCount > 1 {
body["video_count"] = videoCount
}
b, _ := json.Marshal(body)
return b
}
func normalizeVideoCount(mediaType string, videoCount int) int {
if mediaType != "video" {
return 1
}
if videoCount <= 0 {
return 1
}
if videoCount > 3 {
return 3
}
return videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径:ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
// 优先从 ForwardResult 获取(OAuth 路径)
if result != nil && result.MediaURL != "" {
// 尝试从响应体获取完整 URL 列表
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return result.MediaURL, []string{result.MediaURL}
}
// 从响应体解析(APIKey 路径)
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return "", nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func parseMediaURLsFromBody(body []byte) []string {
if len(body) == 0 {
return nil
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil
}
// 优先 media_urls(多图数组)
if rawURLs, ok := resp["media_urls"]; ok {
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
urls := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok && s != "" {
urls = append(urls, s)
}
}
if len(urls) > 0 {
return urls
}
}
}
// 回退到 media_url(单个 URL)
if url, ok := resp["media_url"].(string); ok && url != "" {
return []string{url}
}
return nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
params := service.SoraGenerationListParams{
UserID: userID,
Status: c.Query("status"),
StorageType: c.Query("storage_type"),
MediaType: c.Query("media_type"),
Page: page,
PageSize: pageSize,
}
gens, total, err := h.genService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 为 S3 记录动态生成预签名 URL
for _, gen := range gens {
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
}
response.Success(c, gin.H{
"data": gens,
"total": total,
"page": page,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
response.Success(c, gen)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
paths := gen.MediaURLs
if len(paths) == 0 && gen.MediaURL != "" {
paths = []string{gen.MediaURL}
}
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
}
}
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
response.Success(c, gin.H{"message": "已删除"})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
if h.quotaService == nil {
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
return
}
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, quota)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
// 权限校验
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = gen
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
if errors.Is(err, service.ErrSoraGenerationNotActive) {
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"message": "已取消"})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
if gen.StorageType != service.SoraStorageTypeUpstream {
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
return
}
if gen.MediaURL == "" {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
return
}
sourceURLs := gen.MediaURLs
if len(sourceURLs) == 0 && gen.MediaURL != "" {
sourceURLs = []string{gen.MediaURL}
}
if len(sourceURLs) == 0 {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
uploadedKeys := make([]string, 0, len(sourceURLs))
accessURLs := make([]string, 0, len(sourceURLs))
var totalSize int64
for _, sourceURL := range sourceURLs {
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
if uploadErr != nil {
if len(uploadedKeys) > 0 {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
}
var upstreamErr *service.UpstreamDownloadError
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
return
}
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
return
}
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
if err != nil {
uploadedKeys = append(uploadedKeys, objectKey)
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
return
}
uploadedKeys = append(uploadedKeys, objectKey)
accessURLs = append(accessURLs, accessURL)
totalSize += fileSize
}
usageAdded := false
if totalSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
return
}
usageAdded = true
}
if err := h.genService.UpdateStorageForCompleted(
c.Request.Context(),
id,
accessURLs[0],
accessURLs,
service.SoraStorageTypeS3,
uploadedKeys,
totalSize,
); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"message": "已保存到 S3",
"object_key": uploadedKeys[0],
"object_keys": uploadedKeys,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
s3Healthy := false
if s3Enabled {
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
}
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
response.Success(c, gin.H{
"s3_enabled": s3Enabled,
"s3_healthy": s3Healthy,
"local_enabled": localEnabled,
})
}
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
switch storageType {
case service.SoraStorageTypeS3:
if h.s3Storage != nil && len(s3Keys) > 0 {
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
}
}
case service.SoraStorageTypeLocal:
if h.mediaStorage != nil && len(localPaths) > 0 {
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func getUserIDFromContext(c *gin.Context) int64 {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return subject.UserID
}
if id, ok := c.Get("user_id"); ok {
switch v := id.(type) {
case int64:
return v
case float64:
return int64(v)
case string:
n, _ := strconv.ParseInt(v, 10, 64)
return n
}
}
// 尝试从 JWT claims 获取
if id, ok := c.Get("userID"); ok {
if v, ok := id.(int64); ok {
return v
}
}
return 0
}
func groupIDForLog(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
func trimForLog(raw string, maxLen int) string {
trimmed := strings.TrimSpace(raw)
if maxLen <= 0 || len(trimmed) <= maxLen {
return trimmed
}
return trimmed[:maxLen] + "...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func (h *SoraClientHandler) GetModels(c *gin.Context) {
families := h.getModelFamilies(c.Request.Context())
response.Success(c, families)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
// 读锁检查缓存
h.modelCacheMu.RLock()
ttl := modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
families := h.cachedFamilies
h.modelCacheMu.RUnlock()
return families
}
h.modelCacheMu.RUnlock()
// 写锁更新缓存
h.modelCacheMu.Lock()
defer h.modelCacheMu.Unlock()
// double-check
ttl = modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
return h.cachedFamilies
}
// 尝试从上游获取
families, err := h.fetchUpstreamModels(ctx)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
families = service.BuildSoraModelFamilies()
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = false
return families
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = true
return families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
if h.gatewayService == nil {
return nil, fmt.Errorf("gatewayService 未初始化")
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
// 选择一个 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
if err != nil {
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
}
// 仅支持 API Key 类型账号
if account.Type != service.AccountTypeAPIKey {
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
}
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return nil, fmt.Errorf("账号缺少 api_key")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return nil, fmt.Errorf("账号缺少 base_url")
}
// 构建上游模型列表请求
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求上游失败: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 解析 OpenAI 格式的模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(modelsResp.Data) == 0 {
return nil, fmt.Errorf("上游返回空模型列表")
}
// 提取模型 ID
modelIDs := make([]string, 0, len(modelsResp.Data))
for _, m := range modelsResp.Data {
modelIDs = append(modelIDs, m.ID)
}
// 转换为模型家族
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
if len(families) == 0 {
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
}
return families, nil
}
//go:build unit
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ==================== Stub: SoraGenerationRepository ====================
var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
type stubSoraGenRepo struct {
gens map[int64]*service.SoraGeneration
nextID int64
createErr error
getErr error
updateErr error
deleteErr error
listErr error
countErr error
countValue int64
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
updateCallCount *int32
updateFailAfterN int32
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
getByIDCallCount int32
getByIDOverrideAfterN int32 // 0 = 不覆盖
getByIDOverrideStatus string
}
func newStubSoraGenRepo() *stubSoraGenRepo {
return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
}
func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
if r.createErr != nil {
return r.createErr
}
gen.ID = r.nextID
r.nextID++
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
if r.getErr != nil {
return nil, r.getErr
}
gen, ok := r.gens[id]
if !ok {
return nil, fmt.Errorf("not found")
}
// 条件性状态覆盖:模拟外部取消等场景
if r.getByIDOverrideAfterN > 0 {
n := atomic.AddInt32(&r.getByIDCallCount, 1)
if n > r.getByIDOverrideAfterN {
cp := *gen
cp.Status = r.getByIDOverrideStatus
return &cp, nil
}
}
return gen, nil
}
func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
// 条件性失败:前 N 次成功,之后失败
if r.updateCallCount != nil {
n := atomic.AddInt32(r.updateCallCount, 1)
if n > r.updateFailAfterN {
return fmt.Errorf("conditional update error (call #%d)", n)
}
}
if r.updateErr != nil {
return r.updateErr
}
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
if r.deleteErr != nil {
return r.deleteErr
}
delete(r.gens, id)
return nil
}
func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
if r.listErr != nil {
return nil, 0, r.listErr
}
var result []*service.SoraGeneration
for _, gen := range r.gens {
if gen.UserID != params.UserID {
continue
}
result = append(result, gen)
}
return result, int64(len(result)), nil
}
func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
return r.countValue, nil
}
// ==================== 辅助函数 ====================
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
genService := service.NewSoraGenerationService(repo, nil, nil)
return &SoraClientHandler{genService: genService}
}
func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
if body != "" {
c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
} else {
c.Request = httptest.NewRequest(method, path, nil)
}
if userID > 0 {
c.Set("user_id", userID)
}
return c, rec
}
func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
t.Helper()
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
return resp
}
// ==================== 纯函数测试: buildAsyncRequestBody ====================
func TestBuildAsyncRequestBody(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "sora2-landscape-10s", parsed["model"])
require.Equal(t, false, parsed["stream"])
msgs := parsed["messages"].([]any)
require.Len(t, msgs, 1)
msg := msgs[0].(map[string]any)
require.Equal(t, "user", msg["role"])
require.Equal(t, "一只猫在跳舞", msg["content"])
}
func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "gpt-image", parsed["model"])
msgs := parsed["messages"].([]any)
msg := msgs[0].(map[string]any)
require.Equal(t, "", msg["content"])
}
func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
}
func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, float64(3), parsed["video_count"])
}
func TestNormalizeVideoCount(t *testing.T) {
require.Equal(t, 1, normalizeVideoCount("video", 0))
require.Equal(t, 2, normalizeVideoCount("video", 2))
require.Equal(t, 3, normalizeVideoCount("video", 5))
require.Equal(t, 1, normalizeVideoCount("image", 3))
}
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
}
func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody(nil))
require.Nil(t, parseMediaURLsFromBody([]byte{}))
}
func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
}
func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
}
func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
urls := parseMediaURLsFromBody([]byte(body))
require.Len(t, urls, 2)
require.Equal(t, "https://multi.com/a.mp4", urls[0])
}
func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
}
func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
// media_urls 不是 string 数组
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
}
func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
}
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://oauth.com/video.mp4", url)
require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://body.com/1.mp4", url)
require.Len(t, urls, 2)
}
func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Equal(t, "https://upstream.com/video.mp4", url)
require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
result := &service.ForwardResult{MediaURL: ""}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
// ==================== getUserIDFromContext ====================
func TestGetUserIDFromContext_Int64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", int64(42))
require.Equal(t, int64(42), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
require.Equal(t, int64(777), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_Float64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", float64(99))
require.Equal(t, int64(99), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_String(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "123")
require.Equal(t, int64(123), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("userID", int64(55))
require.Equal(t, int64(55), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_NoID(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
require.Equal(t, int64(0), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_InvalidString(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "not-a-number")
require.Equal(t, int64(0), getUserIDFromContext(c))
}
// ==================== Handler: Generate ====================
func TestGenerate_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
h.Generate(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGenerate_BadRequest_MissingModel(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_TooManyRequests(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countValue = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_CountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_Success(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
require.Equal(t, "pending", data["status"])
}
func TestGenerate_DefaultMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "video", repo.gens[1].MediaType)
}
func TestGenerate_ImageMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "image", repo.gens[1].MediaType)
}
func TestGenerate_CreatePendingError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.createErr = fmt.Errorf("create failed")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGenerate_APIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(42))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_ConcurrencyBoundary(t *testing.T) {
// activeCount == 2 应该允许
repo := newStubSoraGenRepo()
repo.countValue = 2
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Handler: ListGenerations ====================
func TestListGenerations_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
h.ListGenerations(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestListGenerations_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
repo.nextID = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
items := data["data"].([]any)
require.Len(t, items, 2)
require.Equal(t, float64(2), data["total"])
}
func TestListGenerations_ListError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.listErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestListGenerations_DefaultPagination(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
// 不传分页参数,应默认 page=1 page_size=20
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["page"])
}
// ==================== Handler: GetGeneration ====================
func TestGetGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.GetGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGetGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["id"])
}
// ==================== Handler: DeleteGeneration ====================
func TestDeleteGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestDeleteGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDeleteGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
// ==================== Handler: CancelGeneration ====================
func TestCancelGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestCancelGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestCancelGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_Pending(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Generating(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Completed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Failed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Cancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
// ==================== Handler: GetQuota ====================
func TestGetQuota_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
h.GetQuota(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetQuota_NilQuotaService(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "unlimited", data["source"])
}
// ==================== Handler: GetModels ====================
func TestGetModels(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
h.GetModels(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].([]any)
require.Len(t, data, 4)
// 验证类型分布
videoCount, imageCount := 0, 0
for _, item := range data {
m := item.(map[string]any)
if m["type"] == "video" {
videoCount++
} else if m["type"] == "image" {
imageCount++
}
}
require.Equal(t, 3, videoCount)
require.Equal(t, 1, imageCount)
}
// ==================== Handler: GetStorageStatus ====================
func TestGetStorageStatus_NilS3(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, false, data["local_enabled"])
}
func TestGetStorageStatus_LocalEnabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, true, data["local_enabled"])
}
// ==================== Handler: SaveToStorage ====================
func TestSaveToStorage_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestSaveToStorage_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestSaveToStorage_NotUpstream(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_S3Nil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
}
func TestSaveToStorage_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://a.com/1.mp4", url)
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
h := &SoraClientHandler{}
url, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
type stubUserRepoForHandler struct {
users map[int64]*service.User
updateErr error
}
func newStubUserRepoForHandler() *stubUserRepoForHandler {
return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
}
func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
if u, ok := r.users[id]; ok {
return u, nil
}
return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
if r.updateErr != nil {
return r.updateErr
}
r.users[user.ID] = user
return nil
}
func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== NewSoraClientHandler ====================
func TestNewSoraClientHandler(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
}
func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
require.Nil(t, h.apiKeyService)
}
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
type stubAPIKeyRepoForHandler struct {
keys map[int64]*service.APIKey
getErr error
}
func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
}
func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
if r.getErr != nil {
return nil, r.getErr
}
if k, ok := r.keys[id]; ok {
return k, nil
}
return nil, fmt.Errorf("api key not found: %d", id)
}
func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
return "", 0, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
}
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(5)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
// 验证 api_key_id 已关联到生成记录
gen := repo.gens[1]
require.NotNil(t, gen.APIKeyID)
require.Equal(t, int64(42), *gen.APIKeyID)
}
func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
// 前端传递不存在的 api_key_id → 400
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
}
func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
// 前端传递别人的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 999, // 属于 user 999
Status: service.StatusAPIKeyActive,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
}
func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
// 前端传递已禁用的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyDisabled,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
}
func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
// 前端传递配额耗尽的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyQuotaExhausted,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
// 前端传递已过期的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyExpired,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService} // apiKeyService = nil
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: nil, // 无分组
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(10)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(99))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 context 中的 api_key_id=99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
}
func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
// api_key_id=0 不存在 → 400
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
// groupID 不为 nil → 不设置 ForcePlatform
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
gid := int64(5)
h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
require.Equal(t, "cancelled", repo.gens[1].Status)
}
// ==================== GenerateRequest JSON 解析 ====================
func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
// 验证 api_key_id 在 JSON 中正确解析为 *int64
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
require.NoError(t, err)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(42), *req.APIKeyID)
}
func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
// 不传 api_key_id → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
// api_key_id: null → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
// 全字段解析
var req GenerateRequest
err := json.Unmarshal([]byte(`{
"model":"sora2-landscape-10s",
"prompt":"test prompt",
"media_type":"video",
"video_count":2,
"image_input":"data:image/png;base64,abc",
"api_key_id":100
}`), &req)
require.NoError(t, err)
require.Equal(t, "sora2-landscape-10s", req.Model)
require.Equal(t, "test prompt", req.Prompt)
require.Equal(t, "video", req.MediaType)
require.Equal(t, 2, req.VideoCount)
require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(100), *req.APIKeyID)
}
func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
// api_key_id 为 nil 时 JSON 序列化应省略
req := GenerateRequest{Model: "sora2", Prompt: "test"}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
_, hasAPIKeyID := parsed["api_key_id"]
require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
}
func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
// api_key_id 不为 nil 时 JSON 序列化应包含
id := int64(42)
req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
require.Equal(t, float64(42), parsed["api_key_id"])
}
// ==================== GetQuota: 有配额服务 ====================
func TestGetQuota_WithQuotaService_Success(t *testing.T) {
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 3 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "user", data["source"])
require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
require.Equal(t, float64(3*1024*1024), data["used_bytes"])
}
func TestGetQuota_WithQuotaService_Error(t *testing.T) {
// 用户不存在时 GetQuota 返回错误
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
h.GetQuota(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== Generate: 配额检查 ====================
func TestGenerate_QuotaCheckFailed(t *testing.T) {
// 配额超限时返回 429
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 1024,
SoraStorageUsedBytes: 1025, // 已超限
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_QuotaCheckPassed(t *testing.T) {
// 配额充足时允许生成
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
type stubSettingRepoForHandler struct {
values map[string]string
}
func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
if values == nil {
values = make(map[string]string)
}
return &stubSettingRepoForHandler{values: values}
}
func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
if v, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: v}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
if v, ok := r.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
r.values[key] = value
return nil
}
func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string)
for _, k := range keys {
if v, ok := r.values[k]; ok {
result[k] = v
}
}
return result, nil
}
func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
for k, v := range settings {
r.values[k] = v
}
return nil
}
func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
return r.values, nil
}
func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
delete(r.values, key)
return nil
}
// ==================== S3 / MediaStorage 辅助函数 ====================
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
settingRepo := newStubSettingRepoForHandler(map[string]string{
"sora_s3_enabled": "true",
"sora_s3_endpoint": endpoint,
"sora_s3_region": "us-east-1",
"sora_s3_bucket": "test-bucket",
"sora_s3_access_key_id": "AKIATEST",
"sora_s3_secret_access_key": "test-secret",
"sora_s3_prefix": "sora",
"sora_s3_force_path_style": "true",
})
settingService := service.NewSettingService(settingRepo, &config.Config{})
return service.NewSoraS3Storage(settingService)
}
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
func newFakeSourceServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("fake video data for test"))
}))
}
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
func newFakeS3Server(mode string) *httptest.Server {
var counter atomic.Int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_ = r.Body.Close()
switch mode {
case "ok":
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
case "fail":
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
case "fail-second":
n := counter.Add(1)
if n <= 1 {
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
}
}
}))
}
// ==================== processGeneration 直接调用测试 ====================
func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
repo.updateErr = fmt.Errorf("db error")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
require.Equal(t, "generating", repo.gens[1].Status)
require.Empty(t, repo.gens[1].ErrorMessage)
}
func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// gatewayService 未设置 → MarkFailed
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
// ==================== storeMediaWithDegradation: S3 路径 ====================
func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 1)
require.NotEmpty(t, s3Keys[0])
require.Len(t, storedURLs, 1)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURL, fakeS3.URL)
require.Contains(t, storedURL, "/test-bucket/")
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 2)
require.Len(t, storedURLs, 2)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURLs[0], fakeS3.URL)
require.Contains(t, storedURLs[1], fakeS3.URL)
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
// 上游返回 404 → 下载失败 → S3 上传不会开始
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer badSource.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: "/dev/null/invalid_dir",
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
// 本地存储失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeLocal, storageType)
require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
}
func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{
s3Storage: s3Storage,
mediaStorage: mediaStorage,
}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败 → 本地存储成功
require.Equal(t, service.SoraStorageTypeLocal, storageType)
}
// ==================== SaveToStorage: S3 路径 ====================
func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "S3")
}
func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer expiredServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: expiredServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusGone, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "过期")
}
func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Contains(t, data["message"], "S3")
require.NotEmpty(t, data["object_key"])
// 验证记录已更新为 S3 存储
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
}
func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{
sourceServer.URL + "/v1.mp4",
sourceServer.URL + "/v2.mp4",
},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Len(t, data["object_keys"].([]any), 2)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.Len(t, repo.gens[1].S3ObjectKeys, 2)
require.Len(t, repo.gens[1].MediaURLs, 2)
}
func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== GetStorageStatus: S3 路径 ====================
func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
}
func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, true, data["s3_healthy"])
}
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
type stubAccountRepoForHandler struct {
accounts []service.Account
}
func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, fmt.Errorf("account not found")
}
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
return false, nil
}
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
return nil
}
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
return nil
}
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
type stubSoraClientForHandler struct {
videoStatus *service.SoraVideoTaskStatus
}
func (s *stubSoraClientForHandler) Enabled() bool { return true }
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
return s.videoStatus, nil
}
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
}
// ==================== processGeneration: 更多路径测试 ====================
func TestProcessGeneration_SelectAccountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 提供可用账号使 SelectAccountForModel 成功
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// soraGatewayService 为 nil
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
}
func TestProcessGeneration_ForwardError(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回视频任务失败
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "content policy violation",
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
}
func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回 completed 但无 URL
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: nil, // 无 URL
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
}
func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
// 第 2 次返回 "cancelled" 状态,模拟外部取消
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
// 无 S3 和本地存储,降级到 upstream
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].MediaURL)
}
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{sourceServer.URL + "/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
s3Storage := newS3StorageForHandler(fakeS3.URL)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
s3Storage: s3Storage,
quotaService: quotaService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
repo.updateCallCount = new(int32)
repo.updateFailAfterN = 1
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
require.Equal(t, "completed", repo.gens[1].Status)
}
// ==================== cleanupStoredMedia 直接测试 ====================
func TestCleanupStoredMedia_S3Path(t *testing.T) {
// S3 清理路径:s3Storage 为 nil 时不 panic
h := &SoraClientHandler{}
// 不应 panic
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalPath(t *testing.T) {
// 本地清理路径:mediaStorage 为 nil 时不 panic
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
}
func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
// upstream 类型不清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
}
func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
// 空 keys 不触发清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
}
// ==================== DeleteGeneration: 本地存储清理路径 ====================
func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: []string{"video/test.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
// MediaURLs 为空,使用 MediaURL 作为清理路径
tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: nil, // 空
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
// 非本地存储类型 → 跳过清理
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeUpstream,
MediaURL: "https://upstream.com/v.mp4",
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_DeleteError(t *testing.T) {
// repo.Delete 出错
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
repo.deleteErr = fmt.Errorf("delete failed")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== fetchUpstreamModels 测试 ====================
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
h := &SoraClientHandler{}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "gatewayService 未初始化")
}
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "选择 Sora 账号失败")
}
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "不支持模型同步")
}
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"base_url": "https://sora.test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "api_key")
}
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
}
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "状态码 500")
}
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not json"))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "解析响应失败")
}
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "空模型列表")
}
func TestFetchUpstreamModels_Success(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求头
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families, err := h.fetchUpstreamModels(context.Background())
require.NoError(t, err)
require.NotEmpty(t, families)
}
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "未能从上游模型列表中识别")
}
// ==================== getModelFamilies 缓存测试 ====================
func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
h := &SoraClientHandler{}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
require.False(t, h.modelCacheUpstream)
}
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
require.True(t, h.modelCacheUpstream)
// 第二次调用命中缓存
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
}
func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
h := &SoraClientHandler{
cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
modelCacheUpstream: false,
}
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 缓存已刷新,不再是 "old"
found := false
for _, f := range families {
if f.ID == "old" {
found = true
}
}
require.False(t, found, "过期缓存应被刷新")
}
// ==================== processGeneration: groupID 与 ForcePlatform ====================
func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 空账号列表 → SelectAccountForModel 失败
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
// ==================== Generate: CreatePending 并发限制错误 ====================
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
type stubSoraGenRepoWithAtomicCreate struct {
stubSoraGenRepo
limitErr error
}
func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
if r.limitErr != nil {
return r.limitErr
}
return r.stubSoraGenRepo.Create(context.Background(), gen)
}
func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
repo := &stubSoraGenRepoWithAtomicCreate{
stubSoraGenRepo: *newStubSoraGenRepo(),
limitErr: service.ErrSoraGenerationConcurrencyLimit,
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "3")
}
// ==================== SaveToStorage: 配额超限 ====================
func TestSaveToStorage_QuotaExceeded(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户配额已满
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10,
SoraStorageUsedBytes: 10,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: MediaURLs 全为空 ====================
func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: "",
MediaURLs: []string{},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "已过期")
}
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
}
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
}
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "nonexistent/video.mp4",
MediaURLs: []string{"nonexistent/video.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== CancelGeneration: 任务已结束冲突 ====================
func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
usageRecordWorkerPool *service.UsageRecordWorkerPool
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
soraTLSEnabled := true
signKey := ""
mediaRoot := "/app/data/sora"
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
usageRecordWorkerPool: usageRecordWorkerPool,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.sora_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
msgsResult := gjson.GetBytes(body, "messages")
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
var err error
body, err = sjson.SetBytes(body, "stream", true)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
setOpsRequestContext(c, reqModel, clientStream, body)
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forced
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
if platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
return
}
streamStarted := false
subscription, _ := middleware2.GetSubscriptionFromContext(c)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := generateOpenAISessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverBody []byte
var lastFailoverHeaders http.Header
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
if err != nil {
reqLog.Warn("sora.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int("last_upstream_status", lastFailoverStatus),
}
if rayID != "" {
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("last_upstream_content_type", contentType))
}
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
proxyBound := account.ProxyID != nil
proxyID := int64(0)
if account.ProxyID != nil {
proxyID = *account.ProxyID
}
tlsFingerprintEnabled := h.soraTLSEnabled
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("sora.account_wait_counter_increment_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
} else if !canWait {
reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
clientStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("sora.account_slot_acquire_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_switching", fields...)
continue
}
reqLog.Error("sora.forward_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
return
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("sora.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("switch_count", switchCount),
)
return
}
}
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""
}
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Any("panic", recovered),
).Error("sora.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func cloneHTTPHeaders(headers http.Header) http.Header {
if headers == nil {
return nil
}
return headers.Clone()
}
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
if headers != nil {
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
contentType = strings.TrimSpace(headers.Get("content-type"))
if contentType == "" {
contentType = strings.TrimSpace(headers.Get("Content-Type"))
}
}
rayID = soraerror.ExtractCloudflareRayID(headers, body)
return rayID, mitigated, contentType
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
message = strings.TrimSpace(message)
if message == "" {
return false
}
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
lower := strings.ToLower(message)
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
return false
}
}
return true
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// MediaProxy serves local Sora media files.
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned serves local Sora media files with signature verification.
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
c.Status(http.StatusNotFound)
return
}
query := c.Request.URL.Query()
if requireSignature {
if h.soraMediaSigningKey == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体签名未配置",
},
})
return
}
expiresStr := strings.TrimSpace(query.Get("expires"))
signature := strings.TrimSpace(query.Get("sig"))
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil || expires <= time.Now().Unix() {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名已过期",
},
})
return
}
query.Del("sig")
query.Del("expires")
signingQuery := query.Encode()
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名无效",
},
})
return
}
}
if strings.TrimSpace(h.soraMediaRoot) == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体目录未配置",
},
})
return
}
relative := strings.TrimPrefix(cleaned, "/")
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
if _, err := os.Stat(localPath); err != nil {
if os.IsNotExist(err) {
c.Status(http.StatusNotFound)
return
}
c.Status(http.StatusInternalServerError)
return
}
c.File(localPath)
}
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var _ service.SoraClient = (*stubSoraClient)(nil)
var _ service.AccountRepository = (*stubAccountRepo)(nil)
var _ service.GroupRepository = (*stubGroupRepo)(nil)
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
type stubSoraClient struct {
imageURLs []string
}
func (s *stubSoraClient) Enabled() bool { return true }
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
return "upload", nil
}
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
type stubAccountRepo struct {
accounts map[int64]*service.Account
}
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if acc, ok := r.accounts[id]; ok {
return acc, nil
}
return nil, service.ErrAccountNotFound
}
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
var result []*service.Account
for _, id := range ids {
if acc, ok := r.accounts[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
_, ok := r.accounts[id]
return ok, nil
}
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return map[string]int64{}, nil
}
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
var result []service.Account
for _, acc := range r.accounts {
for _, platform := range platforms {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
break
}
}
}
return result, nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
type stubGroupRepo struct {
group *service.Group
}
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, nil
}
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
type stubUsageLogRepo struct{}
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return true, nil
}
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, nil
}
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, nil
}
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
RunMode: config.RunModeSimple,
Gateway: config.GatewayConfig{
SoraStreamMode: "force",
MaxAccountSwitches: 1,
Scheduling: config.GatewaySchedulingConfig{
LoadBatchEnabled: false,
},
},
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.test",
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
groupRepo := &stubGroupRepo{group: group}
usageLogRepo := &stubUsageLogRepo{}
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
gatewayService := service.NewGatewayService(
accountRepo,
groupRepo,
usageLogRepo,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,
concurrencyService,
billingService,
nil,
billingCacheService,
nil,
nil,
deferredService,
nil,
testutil.StubSessionLimitCache{},
nil, // rpmCache
nil, // digestStore
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
apiKey := &service.APIKey{
ID: 1,
UserID: 1,
Status: service.StatusActive,
GroupID: &group.ID,
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
handler.ChatCompletions(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func TestSoraHandler_StreamForcing(t *testing.T) {
// 测试 1:stream=false 时 sjson 强制修改为 true
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
clientStream := gjson.GetBytes(body, "stream").Bool()
require.False(t, clientStream)
newBody, err := sjson.SetBytes(body, "stream", true)
require.NoError(t, err)
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
// 测试 2:stream=true 时不修改
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
require.True(t, gjson.GetBytes(body2, "stream").Bool())
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
require.False(t, gjson.GetBytes(body3, "stream").Bool())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func TestSoraHandler_ValidationExtraction(t *testing.T) {
// model 缺失
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
// model 为数字 → 类型不是 gjson.String,应被拒绝
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
modelResult1b := gjson.GetBytes(body1b, "model")
require.True(t, modelResult1b.Exists())
require.NotEqual(t, gjson.String, modelResult1b.Type)
// messages 缺失
body2 := []byte(`{"model":"sora"}`)
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
// messages 不是 JSON 数组(字符串)
body3 := []byte(`{"model":"sora","messages":"not array"}`)
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
// messages 是对象而非数组 → IsArray 返回 false
body4 := []byte(`{"model":"sora","messages":{}}`)
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
body5 := []byte(`{"model":"sora","messages":[]}`)
msgsResult := gjson.GetBytes(body5, "messages")
require.True(t, msgsResult.IsArray())
require.Equal(t, 0, len(msgsResult.Array()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
gin.SetMode(gin.TestMode)
// 从 body 提取 prompt_cache_key
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
hash := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash)
// 无 prompt_cache_key 且无 header → 空 hash
body2 := []byte(`{"model":"sora"}`)
hash2 := generateOpenAISessionHash(c, body2)
require.Empty(t, hash2)
// header 优先于 body
c.Request.Header.Set("session_id", "from-header")
hash3 := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
}
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
headers := http.Header{}
headers.Set("cf-mitigated", "challenge")
headers.Set("content-type", "text/html")
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
require.Equal(t, "9cff2d62d83bb98d", rayID)
require.Equal(t, "challenge", mitigated)
require.Equal(t, "text/html", contentType)
}
......@@ -2,6 +2,7 @@ package handler
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
......@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse additional filters
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
......@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
......@@ -392,7 +403,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.ErrorFrom(c, err)
return
......
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userUsageRepoCapture struct {
service.UsageLogRepository
listFilters usagestats.UsageLogFilters
}
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(repo, nil, nil, nil)
handler := NewUsageHandler(usageSvc, nil)
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
c.Next()
})
router.GET("/usage", handler.List)
return router
}
func TestUserUsageListRequestTypePriority(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, int64(42), repo.listFilters.UserID)
require.NotNil(t, repo.listFilters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
require.Nil(t, repo.listFilters.Stream)
}
func TestUserUsageListInvalidRequestType(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestUserUsageListInvalidStream(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
t.Helper()
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
WorkerCount: 1,
QueueSize: 8,
TaskTimeout: time.Second,
OverflowPolicy: "drop",
OverflowSamplePercent: 0,
AutoScaleEnabled: false,
})
t.Cleanup(pool.Stop)
return pool
}
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &GatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &GatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
package handler
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助
// 复用 ConcurrencyHelper 的退避 + SSE ping 模式
type UserMsgQueueHelper struct {
queueService *service.UserMessageQueueService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewUserMsgQueueHelper 创建用户消息串行队列辅助
func NewUserMsgQueueHelper(
queueService *service.UserMessageQueueService,
pingFormat SSEPingFormat,
pingInterval time.Duration,
) *UserMsgQueueHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &UserMsgQueueHelper{
queueService: queueService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping
// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放
func (h *UserMsgQueueHelper) AcquireWithWait(
c *gin.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
timeout time.Duration,
reqLog *zap.Logger,
) (releaseFunc func(), err error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 先尝试立即获取
result, err := h.queueService.TryAcquire(ctx, accountID)
if err != nil {
return nil, err // fail-open 已在 service 层处理
}
if result.Acquired {
// 获取成功,执行 RPM 自适应延迟
if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil {
if ctx.Err() != nil {
// 延迟期间 context 取消,释放锁
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
bgCancel()
return nil, ctx.Err()
}
}
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
}
// 需要等待:指数退避轮询
return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog)
}
// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping
func (h *UserMsgQueueHelper) waitForLockWithPing(
c *gin.Context,
ctx context.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), error) {
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
needPing = false
}
}
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("umq wait timeout for account %d", accountID)
case <-pingCh:
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-timer.C:
result, err := h.queueService.TryAcquire(ctx, accountID)
if err != nil {
return nil, err
}
if result.Acquired {
// 获取成功,执行 RPM 自适应延迟
if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil {
if ctx.Err() != nil {
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
bgCancel()
return nil, ctx.Err()
}
}
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
}
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
}
// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次)
func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() {
var once sync.Once
return func() {
once.Do(func() {
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer bgCancel()
if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil {
reqLog.Warn("gateway.umq_release_failed",
zap.Int64("account_id", accountID),
zap.Error(err),
)
} else {
reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID))
}
})
}
}
// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping
// 不获取串行锁,不阻塞并发。返回后即可转发请求。
func (h *UserMsgQueueHelper) ThrottleWithPing(
c *gin.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
timeout time.Duration,
reqLog *zap.Logger,
) error {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM)
if delay <= 0 {
return nil
}
reqLog.Debug("gateway.umq_throttle_delay",
zap.Int64("account_id", accountID),
zap.Duration("delay", delay),
)
// 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
flusher, _ = c.Writer.(http.Flusher)
if flusher == nil {
needPing = false
}
}
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
timer := time.NewTimer(delay)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-pingCh:
// SSE ping 逻辑(与 waitForLockWithPing 一致)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return err
}
flusher.Flush()
case <-timer.C:
return nil
}
}
}
......@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
dataManagementHandler *admin.DataManagementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
......@@ -28,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
......@@ -35,6 +37,7 @@ func ProvideAdminHandlers(
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
DataManagement: dataManagementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
......@@ -49,12 +52,13 @@ func ProvideAdminHandlers(
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
}
}
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService)
func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService, lockService)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
......@@ -74,8 +78,12 @@ func ProvideHandlers(
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers {
return &Handlers{
Auth: authHandler,
......@@ -88,6 +96,8 @@ func ProvideHandlers(
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
SoraClient: soraClientHandler,
Setting: settingHandler,
Totp: totpHandler,
}
......@@ -105,6 +115,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
......@@ -114,6 +125,7 @@ var ProviderSet = wire.NewSet(
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewDataManagementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
......@@ -128,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment