Commit 62e80c60 authored by erio's avatar erio
Browse files

revert: completely remove all Sora functionality

parent dbb248df
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// 上游模型缓存 TTL
modelCacheTTL = 1 * time.Hour // 上游获取成功
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type SoraClientHandler struct {
genService *service.SoraGenerationService
quotaService *service.SoraQuotaService
s3Storage *service.SoraS3Storage
soraGatewayService *service.SoraGatewayService
gatewayService *service.GatewayService
mediaStorage *service.SoraMediaStorage
apiKeyService *service.APIKeyService
// 上游模型缓存
modelCacheMu sync.RWMutex
cachedFamilies []service.SoraModelFamily
modelCacheTime time.Time
modelCacheUpstream bool // 是否来自上游(决定 TTL)
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func NewSoraClientHandler(
genService *service.SoraGenerationService,
quotaService *service.SoraQuotaService,
s3Storage *service.SoraS3Storage,
soraGatewayService *service.SoraGatewayService,
gatewayService *service.GatewayService,
mediaStorage *service.SoraMediaStorage,
apiKeyService *service.APIKeyService,
) *SoraClientHandler {
return &SoraClientHandler{
genService: genService,
quotaService: quotaService,
s3Storage: s3Storage,
soraGatewayService: soraGatewayService,
gatewayService: gatewayService,
mediaStorage: mediaStorage,
apiKeyService: apiKeyService,
}
}
// GenerateRequest 生成请求。
type GenerateRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MediaType string `json:"media_type"` // video / image,默认 video
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func (h *SoraClientHandler) Generate(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
var req GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
return
}
if req.MediaType == "" {
req.MediaType = "video"
}
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
// 并发数检查(最多 3 个)
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if activeCount >= 3 {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if h.quotaService != nil {
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusForbidden, err.Error())
return
}
}
// 获取 API Key ID 和 Group ID
var apiKeyID *int64
var groupID *int64
if req.APIKeyID != nil && h.apiKeyService != nil {
// 前端传递了 api_key_id,需要校验
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
if err != nil {
response.Error(c, http.StatusBadRequest, "API Key 不存在")
return
}
if apiKey.UserID != userID {
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
return
}
if apiKey.Status != service.StatusAPIKeyActive {
response.Error(c, http.StatusForbidden, "API Key 不可用")
return
}
apiKeyID = &apiKey.ID
groupID = apiKey.GroupID
} else if id, ok := c.Get("api_key_id"); ok {
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if v, ok := id.(int64); ok {
apiKeyID = &v
}
}
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
if err != nil {
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
response.ErrorFrom(c, err)
return
}
// 启动后台异步生成 goroutine
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
response.Success(c, gin.H{
"generation_id": gen.ID,
"status": gen.Status,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 标记为生成中
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
genID,
userID,
groupIDForLog(groupID),
model,
mediaType,
videoCount,
strings.TrimSpace(imageInput) != "",
len(strings.TrimSpace(prompt)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if groupID == nil {
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
}
if h.gatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
return
}
// 选择 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
genID,
userID,
groupIDForLog(groupID),
model,
err,
)
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
genID,
userID,
groupIDForLog(groupID),
model,
account.ID,
account.Name,
account.Platform,
account.Type,
)
// 构建 chat completions 请求体(非流式)
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
if h.soraGatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
return
}
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
recorder := httptest.NewRecorder()
mockGinCtx, _ := gin.CreateTestContext(recorder)
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
// 调用 Forward(非流式)
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
err,
)
// 检查是否已取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
return
}
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
return
}
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
if mediaURL == "" {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
)
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
return
}
// 检查任务是否已被取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
return
}
// 三层降级存储:S3 → 本地 → 上游临时 URL
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
usageAdded := false
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
return
}
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
return
}
usageAdded = true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen, _ = h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
// 标记完成
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
}
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
func (h *SoraClientHandler) storeMediaWithDegradation(
ctx context.Context, userID int64, mediaType string,
mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
urls := mediaURLs
if len(urls) == 0 {
urls = []string{mediaURL}
}
// 第一层:尝试 S3
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
keys := make([]string, 0, len(urls))
var totalSize int64
allOK := true
for _, u := range urls {
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
allOK = false
// 清理已上传的文件
if len(keys) > 0 {
_ = h.s3Storage.DeleteObjects(ctx, keys)
}
break
}
keys = append(keys, key)
totalSize += size
}
if allOK && len(keys) > 0 {
accessURLs := make([]string, 0, len(keys))
for _, key := range keys {
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
_ = h.s3Storage.DeleteObjects(ctx, keys)
allOK = false
break
}
accessURLs = append(accessURLs, accessURL)
}
if allOK && len(accessURLs) > 0 {
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
}
}
}
// 第二层:尝试本地存储
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
if err == nil && len(storedPaths) > 0 {
firstPath := storedPaths[0]
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
if sizeErr != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
}
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
}
// 第三层:保留上游临时 URL
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
body := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"stream": false,
}
if imageInput != "" {
body["image_input"] = imageInput
}
if videoCount > 1 {
body["video_count"] = videoCount
}
b, _ := json.Marshal(body)
return b
}
func normalizeVideoCount(mediaType string, videoCount int) int {
if mediaType != "video" {
return 1
}
if videoCount <= 0 {
return 1
}
if videoCount > 3 {
return 3
}
return videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径:ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
// 优先从 ForwardResult 获取(OAuth 路径)
if result != nil && result.MediaURL != "" {
// 尝试从响应体获取完整 URL 列表
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return result.MediaURL, []string{result.MediaURL}
}
// 从响应体解析(APIKey 路径)
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return "", nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func parseMediaURLsFromBody(body []byte) []string {
if len(body) == 0 {
return nil
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil
}
// 优先 media_urls(多图数组)
if rawURLs, ok := resp["media_urls"]; ok {
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
urls := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok && s != "" {
urls = append(urls, s)
}
}
if len(urls) > 0 {
return urls
}
}
}
// 回退到 media_url(单个 URL)
if url, ok := resp["media_url"].(string); ok && url != "" {
return []string{url}
}
return nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
params := service.SoraGenerationListParams{
UserID: userID,
Status: c.Query("status"),
StorageType: c.Query("storage_type"),
MediaType: c.Query("media_type"),
Page: page,
PageSize: pageSize,
}
gens, total, err := h.genService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 为 S3 记录动态生成预签名 URL
for _, gen := range gens {
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
}
response.Success(c, gin.H{
"data": gens,
"total": total,
"page": page,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
response.Success(c, gen)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
paths := gen.MediaURLs
if len(paths) == 0 && gen.MediaURL != "" {
paths = []string{gen.MediaURL}
}
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
}
}
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
response.Success(c, gin.H{"message": "已删除"})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
if h.quotaService == nil {
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
return
}
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, quota)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
// 权限校验
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = gen
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
if errors.Is(err, service.ErrSoraGenerationNotActive) {
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"message": "已取消"})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
if gen.StorageType != service.SoraStorageTypeUpstream {
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
return
}
if gen.MediaURL == "" {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
return
}
sourceURLs := gen.MediaURLs
if len(sourceURLs) == 0 && gen.MediaURL != "" {
sourceURLs = []string{gen.MediaURL}
}
if len(sourceURLs) == 0 {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
uploadedKeys := make([]string, 0, len(sourceURLs))
accessURLs := make([]string, 0, len(sourceURLs))
var totalSize int64
for _, sourceURL := range sourceURLs {
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
if uploadErr != nil {
if len(uploadedKeys) > 0 {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
}
var upstreamErr *service.UpstreamDownloadError
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
return
}
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
return
}
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
if err != nil {
uploadedKeys = append(uploadedKeys, objectKey)
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
return
}
uploadedKeys = append(uploadedKeys, objectKey)
accessURLs = append(accessURLs, accessURL)
totalSize += fileSize
}
usageAdded := false
if totalSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
return
}
usageAdded = true
}
if err := h.genService.UpdateStorageForCompleted(
c.Request.Context(),
id,
accessURLs[0],
accessURLs,
service.SoraStorageTypeS3,
uploadedKeys,
totalSize,
); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"message": "已保存到 S3",
"object_key": uploadedKeys[0],
"object_keys": uploadedKeys,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
s3Healthy := false
if s3Enabled {
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
}
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
response.Success(c, gin.H{
"s3_enabled": s3Enabled,
"s3_healthy": s3Healthy,
"local_enabled": localEnabled,
})
}
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
switch storageType {
case service.SoraStorageTypeS3:
if h.s3Storage != nil && len(s3Keys) > 0 {
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
}
}
case service.SoraStorageTypeLocal:
if h.mediaStorage != nil && len(localPaths) > 0 {
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func getUserIDFromContext(c *gin.Context) int64 {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return subject.UserID
}
if id, ok := c.Get("user_id"); ok {
switch v := id.(type) {
case int64:
return v
case float64:
return int64(v)
case string:
n, _ := strconv.ParseInt(v, 10, 64)
return n
}
}
// 尝试从 JWT claims 获取
if id, ok := c.Get("userID"); ok {
if v, ok := id.(int64); ok {
return v
}
}
return 0
}
func groupIDForLog(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
func trimForLog(raw string, maxLen int) string {
trimmed := strings.TrimSpace(raw)
if maxLen <= 0 || len(trimmed) <= maxLen {
return trimmed
}
return trimmed[:maxLen] + "...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func (h *SoraClientHandler) GetModels(c *gin.Context) {
families := h.getModelFamilies(c.Request.Context())
response.Success(c, families)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
// 读锁检查缓存
h.modelCacheMu.RLock()
ttl := modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
families := h.cachedFamilies
h.modelCacheMu.RUnlock()
return families
}
h.modelCacheMu.RUnlock()
// 写锁更新缓存
h.modelCacheMu.Lock()
defer h.modelCacheMu.Unlock()
// double-check
ttl = modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
return h.cachedFamilies
}
// 尝试从上游获取
families, err := h.fetchUpstreamModels(ctx)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
families = service.BuildSoraModelFamilies()
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = false
return families
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = true
return families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
if h.gatewayService == nil {
return nil, fmt.Errorf("gatewayService 未初始化")
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
// 选择一个 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
if err != nil {
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
}
// 仅支持 API Key 类型账号
if account.Type != service.AccountTypeAPIKey {
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
}
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return nil, fmt.Errorf("账号缺少 api_key")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return nil, fmt.Errorf("账号缺少 base_url")
}
// 构建上游模型列表请求
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求上游失败: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 解析 OpenAI 格式的模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(modelsResp.Data) == 0 {
return nil, fmt.Errorf("上游返回空模型列表")
}
// 提取模型 ID
modelIDs := make([]string, 0, len(modelsResp.Data))
for _, m := range modelsResp.Data {
modelIDs = append(modelIDs, m.ID)
}
// 转换为模型家族
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
if len(families) == 0 {
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
}
return families, nil
}
//go:build unit
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ==================== Stub: SoraGenerationRepository ====================
var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
type stubSoraGenRepo struct {
gens map[int64]*service.SoraGeneration
nextID int64
createErr error
getErr error
updateErr error
deleteErr error
listErr error
countErr error
countValue int64
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
updateCallCount *int32
updateFailAfterN int32
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
getByIDCallCount int32
getByIDOverrideAfterN int32 // 0 = 不覆盖
getByIDOverrideStatus string
}
func newStubSoraGenRepo() *stubSoraGenRepo {
return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
}
func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
if r.createErr != nil {
return r.createErr
}
gen.ID = r.nextID
r.nextID++
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
if r.getErr != nil {
return nil, r.getErr
}
gen, ok := r.gens[id]
if !ok {
return nil, fmt.Errorf("not found")
}
// 条件性状态覆盖:模拟外部取消等场景
if r.getByIDOverrideAfterN > 0 {
n := atomic.AddInt32(&r.getByIDCallCount, 1)
if n > r.getByIDOverrideAfterN {
cp := *gen
cp.Status = r.getByIDOverrideStatus
return &cp, nil
}
}
return gen, nil
}
func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
// 条件性失败:前 N 次成功,之后失败
if r.updateCallCount != nil {
n := atomic.AddInt32(r.updateCallCount, 1)
if n > r.updateFailAfterN {
return fmt.Errorf("conditional update error (call #%d)", n)
}
}
if r.updateErr != nil {
return r.updateErr
}
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
if r.deleteErr != nil {
return r.deleteErr
}
delete(r.gens, id)
return nil
}
func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
if r.listErr != nil {
return nil, 0, r.listErr
}
var result []*service.SoraGeneration
for _, gen := range r.gens {
if gen.UserID != params.UserID {
continue
}
result = append(result, gen)
}
return result, int64(len(result)), nil
}
func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
return r.countValue, nil
}
// ==================== 辅助函数 ====================
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
genService := service.NewSoraGenerationService(repo, nil, nil)
return &SoraClientHandler{genService: genService}
}
func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
if body != "" {
c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
} else {
c.Request = httptest.NewRequest(method, path, nil)
}
if userID > 0 {
c.Set("user_id", userID)
}
return c, rec
}
func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
t.Helper()
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
return resp
}
// ==================== 纯函数测试: buildAsyncRequestBody ====================
func TestBuildAsyncRequestBody(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "sora2-landscape-10s", parsed["model"])
require.Equal(t, false, parsed["stream"])
msgs := parsed["messages"].([]any)
require.Len(t, msgs, 1)
msg := msgs[0].(map[string]any)
require.Equal(t, "user", msg["role"])
require.Equal(t, "一只猫在跳舞", msg["content"])
}
func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "gpt-image", parsed["model"])
msgs := parsed["messages"].([]any)
msg := msgs[0].(map[string]any)
require.Equal(t, "", msg["content"])
}
func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
}
func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, float64(3), parsed["video_count"])
}
func TestNormalizeVideoCount(t *testing.T) {
require.Equal(t, 1, normalizeVideoCount("video", 0))
require.Equal(t, 2, normalizeVideoCount("video", 2))
require.Equal(t, 3, normalizeVideoCount("video", 5))
require.Equal(t, 1, normalizeVideoCount("image", 3))
}
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
}
func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody(nil))
require.Nil(t, parseMediaURLsFromBody([]byte{}))
}
func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
}
func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
}
func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
urls := parseMediaURLsFromBody([]byte(body))
require.Len(t, urls, 2)
require.Equal(t, "https://multi.com/a.mp4", urls[0])
}
func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
}
func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
// media_urls 不是 string 数组
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
}
func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
}
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://oauth.com/video.mp4", url)
require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://body.com/1.mp4", url)
require.Len(t, urls, 2)
}
func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Equal(t, "https://upstream.com/video.mp4", url)
require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
result := &service.ForwardResult{MediaURL: ""}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
// ==================== getUserIDFromContext ====================
func TestGetUserIDFromContext_Int64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", int64(42))
require.Equal(t, int64(42), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
require.Equal(t, int64(777), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_Float64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", float64(99))
require.Equal(t, int64(99), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_String(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "123")
require.Equal(t, int64(123), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("userID", int64(55))
require.Equal(t, int64(55), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_NoID(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
require.Equal(t, int64(0), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_InvalidString(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "not-a-number")
require.Equal(t, int64(0), getUserIDFromContext(c))
}
// ==================== Handler: Generate ====================
func TestGenerate_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
h.Generate(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGenerate_BadRequest_MissingModel(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_TooManyRequests(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countValue = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_CountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_Success(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
require.Equal(t, "pending", data["status"])
}
func TestGenerate_DefaultMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "video", repo.gens[1].MediaType)
}
func TestGenerate_ImageMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "image", repo.gens[1].MediaType)
}
func TestGenerate_CreatePendingError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.createErr = fmt.Errorf("create failed")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGenerate_APIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(42))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_ConcurrencyBoundary(t *testing.T) {
// activeCount == 2 应该允许
repo := newStubSoraGenRepo()
repo.countValue = 2
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Handler: ListGenerations ====================
func TestListGenerations_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
h.ListGenerations(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestListGenerations_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
repo.nextID = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
items := data["data"].([]any)
require.Len(t, items, 2)
require.Equal(t, float64(2), data["total"])
}
func TestListGenerations_ListError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.listErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestListGenerations_DefaultPagination(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
// 不传分页参数,应默认 page=1 page_size=20
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["page"])
}
// ==================== Handler: GetGeneration ====================
func TestGetGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.GetGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGetGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["id"])
}
// ==================== Handler: DeleteGeneration ====================
func TestDeleteGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestDeleteGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDeleteGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
// ==================== Handler: CancelGeneration ====================
func TestCancelGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestCancelGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestCancelGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_Pending(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Generating(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Completed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Failed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Cancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
// ==================== Handler: GetQuota ====================
func TestGetQuota_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
h.GetQuota(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetQuota_NilQuotaService(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "unlimited", data["source"])
}
// ==================== Handler: GetModels ====================
func TestGetModels(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
h.GetModels(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].([]any)
require.Len(t, data, 4)
// 验证类型分布
videoCount, imageCount := 0, 0
for _, item := range data {
m := item.(map[string]any)
if m["type"] == "video" {
videoCount++
} else if m["type"] == "image" {
imageCount++
}
}
require.Equal(t, 3, videoCount)
require.Equal(t, 1, imageCount)
}
// ==================== Handler: GetStorageStatus ====================
func TestGetStorageStatus_NilS3(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, false, data["local_enabled"])
}
func TestGetStorageStatus_LocalEnabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, true, data["local_enabled"])
}
// ==================== Handler: SaveToStorage ====================
func TestSaveToStorage_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestSaveToStorage_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestSaveToStorage_NotUpstream(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_S3Nil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
}
func TestSaveToStorage_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://a.com/1.mp4", url)
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
h := &SoraClientHandler{}
url, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
type stubUserRepoForHandler struct {
users map[int64]*service.User
updateErr error
}
func newStubUserRepoForHandler() *stubUserRepoForHandler {
return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
}
func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
if u, ok := r.users[id]; ok {
return u, nil
}
return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
if r.updateErr != nil {
return r.updateErr
}
r.users[user.ID] = user
return nil
}
func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== NewSoraClientHandler ====================
func TestNewSoraClientHandler(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
}
func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
require.Nil(t, h.apiKeyService)
}
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
type stubAPIKeyRepoForHandler struct {
keys map[int64]*service.APIKey
getErr error
}
func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
}
func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
if r.getErr != nil {
return nil, r.getErr
}
if k, ok := r.keys[id]; ok {
return k, nil
}
return nil, fmt.Errorf("api key not found: %d", id)
}
func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
return "", 0, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) {
var updated int64
for id, key := range r.keys {
if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID {
continue
}
clone := *key
gid := newGroupID
clone.GroupID = &gid
r.keys[id] = &clone
updated++
}
return updated, nil
}
func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
return nil
}
func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
}
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(5)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
// 验证 api_key_id 已关联到生成记录
gen := repo.gens[1]
require.NotNil(t, gen.APIKeyID)
require.Equal(t, int64(42), *gen.APIKeyID)
}
func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
// 前端传递不存在的 api_key_id → 400
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
}
func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
// 前端传递别人的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 999, // 属于 user 999
Status: service.StatusAPIKeyActive,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
}
func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
// 前端传递已禁用的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyDisabled,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
}
func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
// 前端传递配额耗尽的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyQuotaExhausted,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
// 前端传递已过期的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyExpired,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService} // apiKeyService = nil
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: nil, // 无分组
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(10)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(99))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 context 中的 api_key_id=99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
}
func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
// api_key_id=0 不存在 → 400
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
// groupID 不为 nil → 不设置 ForcePlatform
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
gid := int64(5)
h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
require.Equal(t, "cancelled", repo.gens[1].Status)
}
// ==================== GenerateRequest JSON 解析 ====================
func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
// 验证 api_key_id 在 JSON 中正确解析为 *int64
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
require.NoError(t, err)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(42), *req.APIKeyID)
}
func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
// 不传 api_key_id → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
// api_key_id: null → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
// 全字段解析
var req GenerateRequest
err := json.Unmarshal([]byte(`{
"model":"sora2-landscape-10s",
"prompt":"test prompt",
"media_type":"video",
"video_count":2,
"image_input":"data:image/png;base64,abc",
"api_key_id":100
}`), &req)
require.NoError(t, err)
require.Equal(t, "sora2-landscape-10s", req.Model)
require.Equal(t, "test prompt", req.Prompt)
require.Equal(t, "video", req.MediaType)
require.Equal(t, 2, req.VideoCount)
require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(100), *req.APIKeyID)
}
func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
// api_key_id 为 nil 时 JSON 序列化应省略
req := GenerateRequest{Model: "sora2", Prompt: "test"}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
_, hasAPIKeyID := parsed["api_key_id"]
require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
}
func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
// api_key_id 不为 nil 时 JSON 序列化应包含
id := int64(42)
req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
require.Equal(t, float64(42), parsed["api_key_id"])
}
// ==================== GetQuota: 有配额服务 ====================
func TestGetQuota_WithQuotaService_Success(t *testing.T) {
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 3 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "user", data["source"])
require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
require.Equal(t, float64(3*1024*1024), data["used_bytes"])
}
func TestGetQuota_WithQuotaService_Error(t *testing.T) {
// 用户不存在时 GetQuota 返回错误
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
h.GetQuota(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== Generate: 配额检查 ====================
func TestGenerate_QuotaCheckFailed(t *testing.T) {
// 配额超限时返回 429
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 1024,
SoraStorageUsedBytes: 1025, // 已超限
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_QuotaCheckPassed(t *testing.T) {
// 配额充足时允许生成
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
type stubSettingRepoForHandler struct {
values map[string]string
}
func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
if values == nil {
values = make(map[string]string)
}
return &stubSettingRepoForHandler{values: values}
}
func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
if v, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: v}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
if v, ok := r.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
r.values[key] = value
return nil
}
func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string)
for _, k := range keys {
if v, ok := r.values[k]; ok {
result[k] = v
}
}
return result, nil
}
func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
for k, v := range settings {
r.values[k] = v
}
return nil
}
func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
return r.values, nil
}
func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
delete(r.values, key)
return nil
}
// ==================== S3 / MediaStorage 辅助函数 ====================
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
settingRepo := newStubSettingRepoForHandler(map[string]string{
"sora_s3_enabled": "true",
"sora_s3_endpoint": endpoint,
"sora_s3_region": "us-east-1",
"sora_s3_bucket": "test-bucket",
"sora_s3_access_key_id": "AKIATEST",
"sora_s3_secret_access_key": "test-secret",
"sora_s3_prefix": "sora",
"sora_s3_force_path_style": "true",
})
settingService := service.NewSettingService(settingRepo, &config.Config{})
return service.NewSoraS3Storage(settingService)
}
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
func newFakeSourceServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("fake video data for test"))
}))
}
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
func newFakeS3Server(mode string) *httptest.Server {
var counter atomic.Int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_ = r.Body.Close()
switch mode {
case "ok":
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
case "fail":
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
case "fail-second":
n := counter.Add(1)
if n <= 1 {
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
}
}
}))
}
// ==================== processGeneration 直接调用测试 ====================
func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
repo.updateErr = fmt.Errorf("db error")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
require.Equal(t, "generating", repo.gens[1].Status)
require.Empty(t, repo.gens[1].ErrorMessage)
}
func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// gatewayService 未设置 → MarkFailed
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
// ==================== storeMediaWithDegradation: S3 路径 ====================
func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 1)
require.NotEmpty(t, s3Keys[0])
require.Len(t, storedURLs, 1)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURL, fakeS3.URL)
require.Contains(t, storedURL, "/test-bucket/")
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 2)
require.Len(t, storedURLs, 2)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURLs[0], fakeS3.URL)
require.Contains(t, storedURLs[1], fakeS3.URL)
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
// 上游返回 404 → 下载失败 → S3 上传不会开始
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer badSource.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: "/dev/null/invalid_dir",
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
// 本地存储失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeLocal, storageType)
require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
}
func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{
s3Storage: s3Storage,
mediaStorage: mediaStorage,
}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败 → 本地存储成功
require.Equal(t, service.SoraStorageTypeLocal, storageType)
}
// ==================== SaveToStorage: S3 路径 ====================
func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "S3")
}
func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer expiredServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: expiredServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusGone, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "过期")
}
func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Contains(t, data["message"], "S3")
require.NotEmpty(t, data["object_key"])
// 验证记录已更新为 S3 存储
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
}
func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{
sourceServer.URL + "/v1.mp4",
sourceServer.URL + "/v2.mp4",
},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Len(t, data["object_keys"].([]any), 2)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.Len(t, repo.gens[1].S3ObjectKeys, 2)
require.Len(t, repo.gens[1].MediaURLs, 2)
}
func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== GetStorageStatus: S3 路径 ====================
func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
}
func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, true, data["s3_healthy"])
}
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
type stubAccountRepoForHandler struct {
accounts []service.Account
}
func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, fmt.Errorf("account not found")
}
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
return false, nil
}
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
return nil
}
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
return nil
}
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
return nil
}
func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
return nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
type stubSoraClientForHandler struct {
videoStatus *service.SoraVideoTaskStatus
}
func (s *stubSoraClientForHandler) Enabled() bool { return true }
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
return s.videoStatus, nil
}
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
)
}
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
}
// ==================== processGeneration: 更多路径测试 ====================
func TestProcessGeneration_SelectAccountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 提供可用账号使 SelectAccountForModel 成功
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// soraGatewayService 为 nil
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
}
func TestProcessGeneration_ForwardError(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回视频任务失败
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "content policy violation",
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
}
func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回 completed 但无 URL
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: nil, // 无 URL
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
}
func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
// 第 2 次返回 "cancelled" 状态,模拟外部取消
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
// 无 S3 和本地存储,降级到 upstream
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].MediaURL)
}
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{sourceServer.URL + "/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
s3Storage := newS3StorageForHandler(fakeS3.URL)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
s3Storage: s3Storage,
quotaService: quotaService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
repo.updateCallCount = new(int32)
repo.updateFailAfterN = 1
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
require.Equal(t, "completed", repo.gens[1].Status)
}
// ==================== cleanupStoredMedia 直接测试 ====================
func TestCleanupStoredMedia_S3Path(t *testing.T) {
// S3 清理路径:s3Storage 为 nil 时不 panic
h := &SoraClientHandler{}
// 不应 panic
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalPath(t *testing.T) {
// 本地清理路径:mediaStorage 为 nil 时不 panic
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
}
func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
// upstream 类型不清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
}
func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
// 空 keys 不触发清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
}
// ==================== DeleteGeneration: 本地存储清理路径 ====================
func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: []string{"video/test.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
// MediaURLs 为空,使用 MediaURL 作为清理路径
tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: nil, // 空
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
// 非本地存储类型 → 跳过清理
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeUpstream,
MediaURL: "https://upstream.com/v.mp4",
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_DeleteError(t *testing.T) {
// repo.Delete 出错
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
repo.deleteErr = fmt.Errorf("delete failed")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== fetchUpstreamModels 测试 ====================
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
h := &SoraClientHandler{}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "gatewayService 未初始化")
}
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "选择 Sora 账号失败")
}
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "不支持模型同步")
}
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"base_url": "https://sora.test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "api_key")
}
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
}
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "状态码 500")
}
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not json"))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "解析响应失败")
}
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "空模型列表")
}
func TestFetchUpstreamModels_Success(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求头
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families, err := h.fetchUpstreamModels(context.Background())
require.NoError(t, err)
require.NotEmpty(t, families)
}
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "未能从上游模型列表中识别")
}
// ==================== getModelFamilies 缓存测试 ====================
func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
h := &SoraClientHandler{}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
require.False(t, h.modelCacheUpstream)
}
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
require.True(t, h.modelCacheUpstream)
// 第二次调用命中缓存
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
}
func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
h := &SoraClientHandler{
cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
modelCacheUpstream: false,
}
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 缓存已刷新,不再是 "old"
found := false
for _, f := range families {
if f.ID == "old" {
found = true
}
}
require.False(t, found, "过期缓存应被刷新")
}
// ==================== processGeneration: groupID 与 ForcePlatform ====================
func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 空账号列表 → SelectAccountForModel 失败
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
// ==================== Generate: CreatePending 并发限制错误 ====================
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
type stubSoraGenRepoWithAtomicCreate struct {
stubSoraGenRepo
limitErr error
}
func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
if r.limitErr != nil {
return r.limitErr
}
return r.stubSoraGenRepo.Create(context.Background(), gen)
}
func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
repo := &stubSoraGenRepoWithAtomicCreate{
stubSoraGenRepo: *newStubSoraGenRepo(),
limitErr: service.ErrSoraGenerationConcurrencyLimit,
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "3")
}
// ==================== SaveToStorage: 配额超限 ====================
func TestSaveToStorage_QuotaExceeded(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户配额已满
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10,
SoraStorageUsedBytes: 10,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: MediaURLs 全为空 ====================
func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: "",
MediaURLs: []string{},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "已过期")
}
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
}
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
}
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "nonexistent/video.mp4",
MediaURLs: []string{"nonexistent/video.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== CancelGeneration: 任务已结束冲突 ====================
func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
usageRecordWorkerPool *service.UsageRecordWorkerPool
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
soraTLSEnabled := true
signKey := ""
mediaRoot := "/app/data/sora"
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
usageRecordWorkerPool: usageRecordWorkerPool,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.sora_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
msgsResult := gjson.GetBytes(body, "messages")
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
var err error
body, err = sjson.SetBytes(body, "stream", true)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
setOpsRequestContext(c, reqModel, clientStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forced
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
if platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
return
}
streamStarted := false
subscription, _ := middleware2.GetSubscriptionFromContext(c)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := generateOpenAISessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverBody []byte
var lastFailoverHeaders http.Header
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
if err != nil {
reqLog.Warn("sora.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int("last_upstream_status", lastFailoverStatus),
}
if rayID != "" {
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("last_upstream_content_type", contentType))
}
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
proxyBound := account.ProxyID != nil
proxyID := int64(0)
if account.ProxyID != nil {
proxyID = *account.ProxyID
}
tlsFingerprintEnabled := h.soraTLSEnabled
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("sora.account_wait_counter_increment_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
} else if !canWait {
reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
clientStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("sora.account_slot_acquire_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_switching", fields...)
continue
}
reqLog.Error("sora.forward_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
return
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("sora.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("switch_count", switchCount),
)
return
}
}
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""
}
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Any("panic", recovered),
).Error("sora.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func cloneHTTPHeaders(headers http.Header) http.Header {
if headers == nil {
return nil
}
return headers.Clone()
}
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
if headers != nil {
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
contentType = strings.TrimSpace(headers.Get("content-type"))
if contentType == "" {
contentType = strings.TrimSpace(headers.Get("Content-Type"))
}
}
rayID = soraerror.ExtractCloudflareRayID(headers, body)
return rayID, mitigated, contentType
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
message = strings.TrimSpace(message)
if message == "" {
return false
}
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
lower := strings.ToLower(message)
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
return false
}
}
return true
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// MediaProxy serves local Sora media files.
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned serves local Sora media files with signature verification.
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
c.Status(http.StatusNotFound)
return
}
query := c.Request.URL.Query()
if requireSignature {
if h.soraMediaSigningKey == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体签名未配置",
},
})
return
}
expiresStr := strings.TrimSpace(query.Get("expires"))
signature := strings.TrimSpace(query.Get("sig"))
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil || expires <= time.Now().Unix() {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名已过期",
},
})
return
}
query.Del("sig")
query.Del("expires")
signingQuery := query.Encode()
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名无效",
},
})
return
}
}
if strings.TrimSpace(h.soraMediaRoot) == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体目录未配置",
},
})
return
}
relative := strings.TrimPrefix(cleaned, "/")
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
if _, err := os.Stat(localPath); err != nil {
if os.IsNotExist(err) {
c.Status(http.StatusNotFound)
return
}
c.Status(http.StatusInternalServerError)
return
}
c.File(localPath)
}
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var _ service.SoraClient = (*stubSoraClient)(nil)
var _ service.AccountRepository = (*stubAccountRepo)(nil)
var _ service.GroupRepository = (*stubGroupRepo)(nil)
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
type stubSoraClient struct {
imageURLs []string
}
func (s *stubSoraClient) Enabled() bool { return true }
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
return "upload", nil
}
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
type stubAccountRepo struct {
accounts map[int64]*service.Account
}
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if acc, ok := r.accounts[id]; ok {
return acc, nil
}
return nil, service.ErrAccountNotFound
}
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
var result []*service.Account
for _, id := range ids {
if acc, ok := r.accounts[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
_, ok := r.accounts[id]
return ok, nil
}
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return map[string]int64{}, nil
}
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
var result []service.Account
for _, acc := range r.accounts {
for _, platform := range platforms {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
break
}
}
}
return result, nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
type stubGroupRepo struct {
group *service.Group
}
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, nil
}
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
type stubUsageLogRepo struct{}
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return true, nil
}
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, nil
}
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, nil
}
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
RunMode: config.RunModeSimple,
Gateway: config.GatewayConfig{
SoraStreamMode: "force",
MaxAccountSwitches: 1,
Scheduling: config.GatewaySchedulingConfig{
LoadBatchEnabled: false,
},
},
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.test",
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
groupRepo := &stubGroupRepo{group: group}
usageLogRepo := &stubUsageLogRepo{}
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
gatewayService := service.NewGatewayService(
accountRepo,
groupRepo,
usageLogRepo,
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,
concurrencyService,
billingService,
nil,
billingCacheService,
nil,
nil,
deferredService,
nil,
testutil.StubSessionLimitCache{},
nil, // rpmCache
nil, // digestStore
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
apiKey := &service.APIKey{
ID: 1,
UserID: 1,
Status: service.StatusActive,
GroupID: &group.ID,
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
handler.ChatCompletions(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func TestSoraHandler_StreamForcing(t *testing.T) {
// 测试 1:stream=false 时 sjson 强制修改为 true
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
clientStream := gjson.GetBytes(body, "stream").Bool()
require.False(t, clientStream)
newBody, err := sjson.SetBytes(body, "stream", true)
require.NoError(t, err)
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
// 测试 2:stream=true 时不修改
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
require.True(t, gjson.GetBytes(body2, "stream").Bool())
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
require.False(t, gjson.GetBytes(body3, "stream").Bool())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func TestSoraHandler_ValidationExtraction(t *testing.T) {
// model 缺失
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
// model 为数字 → 类型不是 gjson.String,应被拒绝
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
modelResult1b := gjson.GetBytes(body1b, "model")
require.True(t, modelResult1b.Exists())
require.NotEqual(t, gjson.String, modelResult1b.Type)
// messages 缺失
body2 := []byte(`{"model":"sora"}`)
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
// messages 不是 JSON 数组(字符串)
body3 := []byte(`{"model":"sora","messages":"not array"}`)
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
// messages 是对象而非数组 → IsArray 返回 false
body4 := []byte(`{"model":"sora","messages":{}}`)
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
body5 := []byte(`{"model":"sora","messages":[]}`)
msgsResult := gjson.GetBytes(body5, "messages")
require.True(t, msgsResult.IsArray())
require.Equal(t, 0, len(msgsResult.Array()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
gin.SetMode(gin.TestMode)
// 从 body 提取 prompt_cache_key
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
hash := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash)
// 无 prompt_cache_key 且无 header → 空 hash
body2 := []byte(`{"model":"sora"}`)
hash2 := generateOpenAISessionHash(c, body2)
require.Empty(t, hash2)
// header 优先于 body
c.Request.Header.Set("session_id", "from-header")
hash3 := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
}
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
headers := http.Header{}
headers.Set("cf-mitigated", "challenge")
headers.Set("content-type", "text/html")
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
require.Equal(t, "9cff2d62d83bb98d", rayID)
require.Equal(t, "challenge", mitigated)
require.Equal(t, "text/html", contentType)
}
...@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere ...@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
}) })
require.True(t, called.Load(), "panic 后后续任务应仍可执行") require.True(t, called.Load(), "panic 后后续任务应仍可执行")
} }
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
...@@ -86,8 +86,6 @@ func ProvideHandlers( ...@@ -86,8 +86,6 @@ func ProvideHandlers(
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler, totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator, _ *service.IdempotencyCoordinator,
...@@ -104,8 +102,6 @@ func ProvideHandlers( ...@@ -104,8 +102,6 @@ func ProvideHandlers(
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
SoraClient: soraClientHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler, Totp: totpHandler,
} }
...@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet( ...@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler, NewAnnouncementHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
NewTotpHandler, NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
......
...@@ -17,8 +17,6 @@ import ( ...@@ -17,8 +17,6 @@ import (
const ( const (
// OAuth Client ID for OpenAI (Codex CLI official) // OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize" AuthorizeURL = "https://auth.openai.com/oauth/authorize"
...@@ -39,8 +37,6 @@ const ( ...@@ -39,8 +37,6 @@ const (
const ( const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai" OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
) )
// OAuthSession stores OAuth flow state for OpenAI // OAuthSession stores OAuth flow state for OpenAI
...@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor ...@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
} }
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. // OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) { return ClientID, true
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
} }
// TokenRequest represents the token exchange request body // TokenRequest represents the token exchange request body
......
...@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { ...@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
} }
} }
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
...@@ -1692,20 +1692,13 @@ func itoa(v int) string { ...@@ -1692,20 +1692,13 @@ func itoa(v int) string {
} }
// FindByExtraField 根据 extra 字段中的键值对查找账号。 // FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 // 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
// //
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field. // FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query(). accounts, err := r.client.Account.Query().
Where( Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(), dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) { func(s *entsql.Selector) {
path := sqljson.Path(key) path := sqljson.Path(key)
......
...@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k, group.FieldImagePrice1k,
group.FieldImagePrice2k, group.FieldImagePrice2k,
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest, group.FieldFallbackGroupIDOnInvalidRequest,
...@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User { ...@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted, TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled, TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt, TotpEnabledAt: u.TotpEnabledAt,
...@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k, ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
......
...@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject). SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
...@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject). SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
......
...@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { ...@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
} }
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id" const customClientID = "custom-client-id"
var seenClientIDs []string var seenClientIDs []string
...@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { ...@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
} }
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1) errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm() _ = r.ParseForm()
......
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
...@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
...@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount ...@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil return nil
} }
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
} }
......
...@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet( ...@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository, NewAPIKeyRepository,
NewGroupRepository, NewGroupRepository,
NewAccountRepository, NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储 NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储 NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository, NewProxyRepository,
......
...@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool { ...@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") || return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses") strings.HasPrefix(path, "/responses")
} }
......
...@@ -109,7 +109,6 @@ func registerRoutes( ...@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
} }
...@@ -34,8 +34,6 @@ func RegisterAdminRoutes( ...@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth // Gemini OAuth
registerGeminiOAuthRoutes(admin, h) registerGeminiOAuthRoutes(admin, h)
...@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini") gemini := admin.Group("/gemini")
{ {
......
...@@ -23,11 +23,6 @@ func RegisterGatewayRoutes( ...@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware() endpointNorm := handler.InboundEndpointMiddleware()
...@@ -163,28 +158,6 @@ func RegisterGatewayRoutes( ...@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
} }
// getGroupPlatform extracts the group platform from the API Key stored in context. // getGroupPlatform extracts the group platform from the API Key stored in context.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment