Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
eb2dce92
Commit
eb2dce92
authored
Apr 06, 2026
by
陈曦
Browse files
升级v1.0.8 解决冲突
parents
7b83d6e7
339d906e
Changes
178
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/dto/settings.go
View file @
eb2dce92
...
...
@@ -61,7 +61,6 @@ type SystemSettings struct {
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
CustomEndpoints
[]
CustomEndpoint
`json:"custom_endpoints"`
...
...
@@ -128,49 +127,10 @@ type PublicSettings struct {
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
CustomEndpoints
[]
CustomEndpoint
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
BackendModeEnabled
bool
`json:"backend_mode_enabled"`
Version
string
`json:"version"`
}
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
type
ListSoraS3ProfilesResponse
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
type
OverloadCooldownSettings
struct
{
Enabled
bool
`json:"enabled"`
...
...
backend/internal/handler/dto/types.go
View file @
eb2dce92
...
...
@@ -26,9 +26,7 @@ type AdminUser struct {
Notes
string
`json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates
map
[
int64
]
float64
`json:"group_rates,omitempty"`
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
SoraStorageUsedBytes
int64
`json:"sora_storage_used_bytes"`
GroupRates
map
[
int64
]
float64
`json:"group_rates,omitempty"`
}
type
APIKey
struct
{
...
...
@@ -84,21 +82,12 @@ type Group struct {
ImagePrice2K
*
float64
`json:"image_price_2k"`
ImagePrice4K
*
float64
`json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360
*
float64
`json:"sora_image_price_360"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// Sora 存储配额
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
AllowMessagesDispatch
bool
`json:"allow_messages_dispatch"`
...
...
backend/internal/handler/endpoint.go
View file @
eb2dce92
...
...
@@ -31,7 +31,7 @@ const (
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint maps a raw request path (which may carry
// prefixes like /antigravity, /openai
, /sora
) to its canonical form.
// prefixes like /antigravity, /openai) to its canonical form.
//
// "/antigravity/v1/messages" → "/v1/messages"
// "/v1/chat/completions" → "/v1/chat/completions"
...
...
@@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string {
// such as /v1/responses/compact preserved from the raw URL).
// - Anthropic → /v1/messages
// - Gemini → /v1beta/models
// -
Sora → /v1/chat/completions
// -
Antigravity → /v1/messages (Claude) or gemini (Gemini)
// - Antigravity routes may target either Claude or Gemini, so the
// inbound endpoint is used to distinguish.
func
DeriveUpstreamEndpoint
(
inbound
,
rawRequestPath
,
platform
string
)
string
{
...
...
@@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
case
service
.
PlatformGemini
:
return
EndpointGeminiModels
case
service
.
PlatformSora
:
return
EndpointChatCompletions
case
service
.
PlatformAntigravity
:
// Antigravity accounts serve both Claude and Gemini.
if
inbound
==
EndpointGeminiModels
{
...
...
backend/internal/handler/endpoint_test.go
View file @
eb2dce92
...
...
@@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{
"/v1/responses"
,
EndpointResponses
},
{
"/v1beta/models"
,
EndpointGeminiModels
},
// Prefixed paths (antigravity, openai
, sora
).
// Prefixed paths (antigravity, openai).
{
"/antigravity/v1/messages"
,
EndpointMessages
},
{
"/openai/v1/responses"
,
EndpointResponses
},
{
"/openai/v1/responses/compact"
,
EndpointResponses
},
{
"/sora/v1/chat/completions"
,
EndpointChatCompletions
},
{
"/antigravity/v1beta/models/gemini:generateContent"
,
EndpointGeminiModels
},
// Gin route patterns with wildcards.
...
...
@@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
// Gemini.
{
"gemini models"
,
EndpointGeminiModels
,
"/v1beta/models/gemini:gen"
,
service
.
PlatformGemini
,
EndpointGeminiModels
},
// Sora.
{
"sora completions"
,
EndpointChatCompletions
,
"/sora/v1/chat/completions"
,
service
.
PlatformSora
,
EndpointChatCompletions
},
// OpenAI — always /v1/responses.
{
"openai responses root"
,
EndpointResponses
,
"/v1/responses"
,
service
.
PlatformOpenAI
,
EndpointResponses
},
{
"openai responses compact"
,
EndpointResponses
,
"/openai/v1/responses/compact"
,
service
.
PlatformOpenAI
,
"/v1/responses/compact"
},
...
...
backend/internal/handler/gateway_handler.go
View file @
eb2dce92
...
...
@@ -859,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform
=
forcedPlatform
}
if
platform
==
service
.
PlatformSora
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"object"
:
"list"
,
"data"
:
service
.
DefaultSoraModels
(
h
.
cfg
),
})
return
}
// Get available models from account configurations (without platform filter)
availableModels
:=
h
.
gatewayService
.
GetAvailableModels
(
c
.
Request
.
Context
(),
groupID
,
""
)
...
...
backend/internal/handler/handler.go
View file @
eb2dce92
...
...
@@ -45,8 +45,6 @@ type Handlers struct {
Admin
*
AdminHandlers
Gateway
*
GatewayHandler
OpenAIGateway
*
OpenAIGatewayHandler
SoraGateway
*
SoraGatewayHandler
SoraClient
*
SoraClientHandler
Setting
*
SettingHandler
Totp
*
TotpHandler
}
...
...
backend/internal/handler/setting_handler.go
View file @
eb2dce92
...
...
@@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems
:
dto
.
ParseUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
SoraClientEnabled
:
settings
.
SoraClientEnabled
,
BackendModeEnabled
:
settings
.
BackendModeEnabled
,
Version
:
h
.
version
,
})
...
...
backend/internal/handler/sora_client_handler.go
deleted
100644 → 0
View file @
7b83d6e7
package
handler
import
(
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const
(
// 上游模型缓存 TTL
modelCacheTTL
=
1
*
time
.
Hour
// 上游获取成功
modelCacheFailedTTL
=
2
*
time
.
Minute
// 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type
SoraClientHandler
struct
{
genService
*
service
.
SoraGenerationService
quotaService
*
service
.
SoraQuotaService
s3Storage
*
service
.
SoraS3Storage
soraGatewayService
*
service
.
SoraGatewayService
gatewayService
*
service
.
GatewayService
mediaStorage
*
service
.
SoraMediaStorage
apiKeyService
*
service
.
APIKeyService
// 上游模型缓存
modelCacheMu
sync
.
RWMutex
cachedFamilies
[]
service
.
SoraModelFamily
modelCacheTime
time
.
Time
modelCacheUpstream
bool
// 是否来自上游(决定 TTL)
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func
NewSoraClientHandler
(
genService
*
service
.
SoraGenerationService
,
quotaService
*
service
.
SoraQuotaService
,
s3Storage
*
service
.
SoraS3Storage
,
soraGatewayService
*
service
.
SoraGatewayService
,
gatewayService
*
service
.
GatewayService
,
mediaStorage
*
service
.
SoraMediaStorage
,
apiKeyService
*
service
.
APIKeyService
,
)
*
SoraClientHandler
{
return
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
s3Storage
:
s3Storage
,
soraGatewayService
:
soraGatewayService
,
gatewayService
:
gatewayService
,
mediaStorage
:
mediaStorage
,
apiKeyService
:
apiKeyService
,
}
}
// GenerateRequest 生成请求。
type
GenerateRequest
struct
{
Model
string
`json:"model" binding:"required"`
Prompt
string
`json:"prompt" binding:"required"`
MediaType
string
`json:"media_type"`
// video / image,默认 video
VideoCount
int
`json:"video_count,omitempty"`
// 视频数量(1-3)
ImageInput
string
`json:"image_input,omitempty"`
// 参考图(base64 或 URL)
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
// 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func
(
h
*
SoraClientHandler
)
Generate
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
var
req
GenerateRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"参数错误: "
+
err
.
Error
())
return
}
if
req
.
MediaType
==
""
{
req
.
MediaType
=
"video"
}
req
.
VideoCount
=
normalizeVideoCount
(
req
.
MediaType
,
req
.
VideoCount
)
// 并发数检查(最多 3 个)
activeCount
,
err
:=
h
.
genService
.
CountActiveByUser
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
activeCount
>=
3
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"同时进行中的任务不能超过 3 个"
)
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
CheckQuota
(
c
.
Request
.
Context
(),
userID
,
0
);
err
!=
nil
{
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
response
.
Error
(
c
,
http
.
StatusForbidden
,
err
.
Error
())
return
}
}
// 获取 API Key ID 和 Group ID
var
apiKeyID
*
int64
var
groupID
*
int64
if
req
.
APIKeyID
!=
nil
&&
h
.
apiKeyService
!=
nil
{
// 前端传递了 api_key_id,需要校验
apiKey
,
err
:=
h
.
apiKeyService
.
GetByID
(
c
.
Request
.
Context
(),
*
req
.
APIKeyID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"API Key 不存在"
)
return
}
if
apiKey
.
UserID
!=
userID
{
response
.
Error
(
c
,
http
.
StatusForbidden
,
"API Key 不属于当前用户"
)
return
}
if
apiKey
.
Status
!=
service
.
StatusAPIKeyActive
{
response
.
Error
(
c
,
http
.
StatusForbidden
,
"API Key 不可用"
)
return
}
apiKeyID
=
&
apiKey
.
ID
groupID
=
apiKey
.
GroupID
}
else
if
id
,
ok
:=
c
.
Get
(
"api_key_id"
);
ok
{
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if
v
,
ok
:=
id
.
(
int64
);
ok
{
apiKeyID
=
&
v
}
}
gen
,
err
:=
h
.
genService
.
CreatePending
(
c
.
Request
.
Context
(),
userID
,
apiKeyID
,
req
.
Model
,
req
.
Prompt
,
req
.
MediaType
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationConcurrencyLimit
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"同时进行中的任务不能超过 3 个"
)
return
}
response
.
ErrorFrom
(
c
,
err
)
return
}
// 启动后台异步生成 goroutine
go
h
.
processGeneration
(
gen
.
ID
,
userID
,
groupID
,
req
.
Model
,
req
.
Prompt
,
req
.
MediaType
,
req
.
ImageInput
,
req
.
VideoCount
)
response
.
Success
(
c
,
gin
.
H
{
"generation_id"
:
gen
.
ID
,
"status"
:
gen
.
Status
,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
func
(
h
*
SoraClientHandler
)
processGeneration
(
genID
int64
,
userID
int64
,
groupID
*
int64
,
model
,
prompt
,
mediaType
,
imageInput
string
,
videoCount
int
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Minute
)
defer
cancel
()
// 标记为生成中
if
err
:=
h
.
genService
.
MarkGenerating
(
ctx
,
genID
,
""
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationStateConflict
)
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 任务状态已变化,跳过生成 id=%d"
,
genID
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 标记生成中失败 id=%d err=%v"
,
genID
,
err
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
mediaType
,
videoCount
,
strings
.
TrimSpace
(
imageInput
)
!=
""
,
len
(
strings
.
TrimSpace
(
prompt
)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if
groupID
==
nil
{
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformSora
)
}
if
h
.
gatewayService
==
nil
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"内部错误: gatewayService 未初始化"
)
return
}
// 选择 Sora 账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
ctx
,
groupID
,
""
,
model
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
err
,
)
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"选择账号失败: "
+
err
.
Error
())
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
account
.
Type
,
)
// 构建 chat completions 请求体(非流式)
body
:=
buildAsyncRequestBody
(
model
,
prompt
,
imageInput
,
normalizeVideoCount
(
mediaType
,
videoCount
))
if
h
.
soraGatewayService
==
nil
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"内部错误: soraGatewayService 未初始化"
)
return
}
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
recorder
:=
httptest
.
NewRecorder
()
mockGinCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
mockGinCtx
.
Request
,
_
=
http
.
NewRequest
(
"POST"
,
"/"
,
nil
)
// 调用 Forward(非流式)
result
,
err
:=
h
.
soraGatewayService
.
Forward
(
ctx
,
mockGinCtx
,
account
,
body
,
false
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v"
,
genID
,
account
.
ID
,
model
,
recorder
.
Code
,
trimForLog
(
recorder
.
Body
.
String
(),
400
),
err
,
)
// 检查是否已取消
gen
,
_
:=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
return
}
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"生成失败: "
+
err
.
Error
())
return
}
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
mediaURL
,
mediaURLs
:=
extractMediaURLsFromResult
(
result
,
recorder
)
if
mediaURL
==
""
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s"
,
genID
,
account
.
ID
,
model
,
recorder
.
Code
,
trimForLog
(
recorder
.
Body
.
String
(),
400
),
)
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"未获取到媒体 URL"
)
return
}
// 检查任务是否已被取消
gen
,
_
:=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 任务已取消,跳过存储 id=%d"
,
genID
)
return
}
// 三层降级存储:S3 → 本地 → 上游临时 URL
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
ctx
,
userID
,
mediaType
,
mediaURL
,
mediaURLs
)
usageAdded
:=
false
if
(
storageType
==
service
.
SoraStorageTypeS3
||
storageType
==
service
.
SoraStorageTypeLocal
)
&&
fileSize
>
0
&&
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
AddUsage
(
ctx
,
userID
,
fileSize
);
err
!=
nil
{
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"存储配额更新失败: "
+
err
.
Error
())
return
}
usageAdded
=
true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen
,
_
=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d"
,
genID
)
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
fileSize
)
}
return
}
// 标记完成
if
err
:=
h
.
genService
.
MarkCompleted
(
ctx
,
genID
,
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationStateConflict
)
{
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
fileSize
)
}
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 标记完成失败 id=%d err=%v"
,
genID
,
err
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 生成完成 id=%d storage=%s size=%d"
,
genID
,
storageType
,
fileSize
)
}
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
func
(
h
*
SoraClientHandler
)
storeMediaWithDegradation
(
ctx
context
.
Context
,
userID
int64
,
mediaType
string
,
mediaURL
string
,
mediaURLs
[]
string
,
)
(
storedURL
string
,
storedURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSize
int64
)
{
urls
:=
mediaURLs
if
len
(
urls
)
==
0
{
urls
=
[]
string
{
mediaURL
}
}
// 第一层:尝试 S3
if
h
.
s3Storage
!=
nil
&&
h
.
s3Storage
.
Enabled
(
ctx
)
{
keys
:=
make
([]
string
,
0
,
len
(
urls
))
var
totalSize
int64
allOK
:=
true
for
_
,
u
:=
range
urls
{
key
,
size
,
err
:=
h
.
s3Storage
.
UploadFromURL
(
ctx
,
userID
,
u
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] S3 上传失败 err=%v"
,
err
)
allOK
=
false
// 清理已上传的文件
if
len
(
keys
)
>
0
{
_
=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
keys
)
}
break
}
keys
=
append
(
keys
,
key
)
totalSize
+=
size
}
if
allOK
&&
len
(
keys
)
>
0
{
accessURLs
:=
make
([]
string
,
0
,
len
(
keys
))
for
_
,
key
:=
range
keys
{
accessURL
,
err
:=
h
.
s3Storage
.
GetAccessURL
(
ctx
,
key
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 生成 S3 访问 URL 失败 err=%v"
,
err
)
_
=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
keys
)
allOK
=
false
break
}
accessURLs
=
append
(
accessURLs
,
accessURL
)
}
if
allOK
&&
len
(
accessURLs
)
>
0
{
return
accessURLs
[
0
],
accessURLs
,
service
.
SoraStorageTypeS3
,
keys
,
totalSize
}
}
}
// 第二层:尝试本地存储
if
h
.
mediaStorage
!=
nil
&&
h
.
mediaStorage
.
Enabled
()
{
storedPaths
,
err
:=
h
.
mediaStorage
.
StoreFromURLs
(
ctx
,
mediaType
,
urls
)
if
err
==
nil
&&
len
(
storedPaths
)
>
0
{
firstPath
:=
storedPaths
[
0
]
totalSize
,
sizeErr
:=
h
.
mediaStorage
.
TotalSizeByRelativePaths
(
storedPaths
)
if
sizeErr
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 统计本地文件大小失败 err=%v"
,
sizeErr
)
}
return
firstPath
,
storedPaths
,
service
.
SoraStorageTypeLocal
,
nil
,
totalSize
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 本地存储失败 err=%v"
,
err
)
}
// 第三层:保留上游临时 URL
return
urls
[
0
],
urls
,
service
.
SoraStorageTypeUpstream
,
nil
,
0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func
buildAsyncRequestBody
(
model
,
prompt
,
imageInput
string
,
videoCount
int
)
[]
byte
{
body
:=
map
[
string
]
any
{
"model"
:
model
,
"messages"
:
[]
map
[
string
]
string
{
{
"role"
:
"user"
,
"content"
:
prompt
},
},
"stream"
:
false
,
}
if
imageInput
!=
""
{
body
[
"image_input"
]
=
imageInput
}
if
videoCount
>
1
{
body
[
"video_count"
]
=
videoCount
}
b
,
_
:=
json
.
Marshal
(
body
)
return
b
}
func
normalizeVideoCount
(
mediaType
string
,
videoCount
int
)
int
{
if
mediaType
!=
"video"
{
return
1
}
if
videoCount
<=
0
{
return
1
}
if
videoCount
>
3
{
return
3
}
return
videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径:ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func
extractMediaURLsFromResult
(
result
*
service
.
ForwardResult
,
recorder
*
httptest
.
ResponseRecorder
)
(
string
,
[]
string
)
{
// 优先从 ForwardResult 获取(OAuth 路径)
if
result
!=
nil
&&
result
.
MediaURL
!=
""
{
// 尝试从响应体获取完整 URL 列表
if
urls
:=
parseMediaURLsFromBody
(
recorder
.
Body
.
Bytes
());
len
(
urls
)
>
0
{
return
urls
[
0
],
urls
}
return
result
.
MediaURL
,
[]
string
{
result
.
MediaURL
}
}
// 从响应体解析(APIKey 路径)
if
urls
:=
parseMediaURLsFromBody
(
recorder
.
Body
.
Bytes
());
len
(
urls
)
>
0
{
return
urls
[
0
],
urls
}
return
""
,
nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func
parseMediaURLsFromBody
(
body
[]
byte
)
[]
string
{
if
len
(
body
)
==
0
{
return
nil
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
}
// 优先 media_urls(多图数组)
if
rawURLs
,
ok
:=
resp
[
"media_urls"
];
ok
{
if
arr
,
ok
:=
rawURLs
.
([]
any
);
ok
&&
len
(
arr
)
>
0
{
urls
:=
make
([]
string
,
0
,
len
(
arr
))
for
_
,
item
:=
range
arr
{
if
s
,
ok
:=
item
.
(
string
);
ok
&&
s
!=
""
{
urls
=
append
(
urls
,
s
)
}
}
if
len
(
urls
)
>
0
{
return
urls
}
}
}
// 回退到 media_url(单个 URL)
if
url
,
ok
:=
resp
[
"media_url"
]
.
(
string
);
ok
&&
url
!=
""
{
return
[]
string
{
url
}
}
return
nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func
(
h
*
SoraClientHandler
)
ListGenerations
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
page
,
_
:=
strconv
.
Atoi
(
c
.
DefaultQuery
(
"page"
,
"1"
))
pageSize
,
_
:=
strconv
.
Atoi
(
c
.
DefaultQuery
(
"page_size"
,
"20"
))
params
:=
service
.
SoraGenerationListParams
{
UserID
:
userID
,
Status
:
c
.
Query
(
"status"
),
StorageType
:
c
.
Query
(
"storage_type"
),
MediaType
:
c
.
Query
(
"media_type"
),
Page
:
page
,
PageSize
:
pageSize
,
}
gens
,
total
,
err
:=
h
.
genService
.
List
(
c
.
Request
.
Context
(),
params
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// 为 S3 记录动态生成预签名 URL
for
_
,
gen
:=
range
gens
{
_
=
h
.
genService
.
ResolveMediaURLs
(
c
.
Request
.
Context
(),
gen
)
}
response
.
Success
(
c
,
gin
.
H
{
"data"
:
gens
,
"total"
:
total
,
"page"
:
page
,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func
(
h
*
SoraClientHandler
)
GetGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
_
=
h
.
genService
.
ResolveMediaURLs
(
c
.
Request
.
Context
(),
gen
)
response
.
Success
(
c
,
gen
)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func
(
h
*
SoraClientHandler
)
DeleteGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if
gen
.
StorageType
==
service
.
SoraStorageTypeLocal
&&
h
.
mediaStorage
!=
nil
{
paths
:=
gen
.
MediaURLs
if
len
(
paths
)
==
0
&&
gen
.
MediaURL
!=
""
{
paths
=
[]
string
{
gen
.
MediaURL
}
}
if
err
:=
h
.
mediaStorage
.
DeleteByRelativePaths
(
paths
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 删除本地文件失败 id=%d err=%v"
,
id
,
err
)
}
}
if
err
:=
h
.
genService
.
Delete
(
c
.
Request
.
Context
(),
id
,
userID
);
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已删除"
})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func
(
h
*
SoraClientHandler
)
GetQuota
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
if
h
.
quotaService
==
nil
{
response
.
Success
(
c
,
service
.
QuotaInfo
{
QuotaSource
:
"unlimited"
,
Source
:
"unlimited"
})
return
}
quota
,
err
:=
h
.
quotaService
.
GetQuota
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
quota
)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func
(
h
*
SoraClientHandler
)
CancelGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
// 权限校验
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
_
=
gen
if
err
:=
h
.
genService
.
MarkCancelled
(
c
.
Request
.
Context
(),
id
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationNotActive
)
{
response
.
Error
(
c
,
http
.
StatusConflict
,
"任务已结束,无法取消"
)
return
}
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已取消"
})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func
(
h
*
SoraClientHandler
)
SaveToStorage
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
if
gen
.
StorageType
!=
service
.
SoraStorageTypeUpstream
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"仅 upstream 类型的记录可手动保存"
)
return
}
if
gen
.
MediaURL
==
""
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"媒体 URL 为空,可能已过期"
)
return
}
if
h
.
s3Storage
==
nil
||
!
h
.
s3Storage
.
Enabled
(
c
.
Request
.
Context
())
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"云存储未配置,请联系管理员"
)
return
}
sourceURLs
:=
gen
.
MediaURLs
if
len
(
sourceURLs
)
==
0
&&
gen
.
MediaURL
!=
""
{
sourceURLs
=
[]
string
{
gen
.
MediaURL
}
}
if
len
(
sourceURLs
)
==
0
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"媒体 URL 为空,可能已过期"
)
return
}
uploadedKeys
:=
make
([]
string
,
0
,
len
(
sourceURLs
))
accessURLs
:=
make
([]
string
,
0
,
len
(
sourceURLs
))
var
totalSize
int64
for
_
,
sourceURL
:=
range
sourceURLs
{
objectKey
,
fileSize
,
uploadErr
:=
h
.
s3Storage
.
UploadFromURL
(
c
.
Request
.
Context
(),
userID
,
sourceURL
)
if
uploadErr
!=
nil
{
if
len
(
uploadedKeys
)
>
0
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
}
var
upstreamErr
*
service
.
UpstreamDownloadError
if
errors
.
As
(
uploadErr
,
&
upstreamErr
)
&&
(
upstreamErr
.
StatusCode
==
http
.
StatusForbidden
||
upstreamErr
.
StatusCode
==
http
.
StatusNotFound
)
{
response
.
Error
(
c
,
http
.
StatusGone
,
"媒体链接已过期,无法保存"
)
return
}
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"上传到 S3 失败: "
+
uploadErr
.
Error
())
return
}
accessURL
,
err
:=
h
.
s3Storage
.
GetAccessURL
(
c
.
Request
.
Context
(),
objectKey
)
if
err
!=
nil
{
uploadedKeys
=
append
(
uploadedKeys
,
objectKey
)
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"生成 S3 访问链接失败: "
+
err
.
Error
())
return
}
uploadedKeys
=
append
(
uploadedKeys
,
objectKey
)
accessURLs
=
append
(
accessURLs
,
accessURL
)
totalSize
+=
fileSize
}
usageAdded
:=
false
if
totalSize
>
0
&&
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
AddUsage
(
c
.
Request
.
Context
(),
userID
,
totalSize
);
err
!=
nil
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"配额更新失败: "
+
err
.
Error
())
return
}
usageAdded
=
true
}
if
err
:=
h
.
genService
.
UpdateStorageForCompleted
(
c
.
Request
.
Context
(),
id
,
accessURLs
[
0
],
accessURLs
,
service
.
SoraStorageTypeS3
,
uploadedKeys
,
totalSize
,
);
err
!=
nil
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
c
.
Request
.
Context
(),
userID
,
totalSize
)
}
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已保存到 S3"
,
"object_key"
:
uploadedKeys
[
0
],
"object_keys"
:
uploadedKeys
,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func
(
h
*
SoraClientHandler
)
GetStorageStatus
(
c
*
gin
.
Context
)
{
s3Enabled
:=
h
.
s3Storage
!=
nil
&&
h
.
s3Storage
.
Enabled
(
c
.
Request
.
Context
())
s3Healthy
:=
false
if
s3Enabled
{
s3Healthy
=
h
.
s3Storage
.
IsHealthy
(
c
.
Request
.
Context
())
}
localEnabled
:=
h
.
mediaStorage
!=
nil
&&
h
.
mediaStorage
.
Enabled
()
response
.
Success
(
c
,
gin
.
H
{
"s3_enabled"
:
s3Enabled
,
"s3_healthy"
:
s3Healthy
,
"local_enabled"
:
localEnabled
,
})
}
func
(
h
*
SoraClientHandler
)
cleanupStoredMedia
(
ctx
context
.
Context
,
storageType
string
,
s3Keys
[]
string
,
localPaths
[]
string
)
{
switch
storageType
{
case
service
.
SoraStorageTypeS3
:
if
h
.
s3Storage
!=
nil
&&
len
(
s3Keys
)
>
0
{
if
err
:=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
s3Keys
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 清理 S3 文件失败 keys=%v err=%v"
,
s3Keys
,
err
)
}
}
case
service
.
SoraStorageTypeLocal
:
if
h
.
mediaStorage
!=
nil
&&
len
(
localPaths
)
>
0
{
if
err
:=
h
.
mediaStorage
.
DeleteByRelativePaths
(
localPaths
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 清理本地文件失败 paths=%v err=%v"
,
localPaths
,
err
)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func
getUserIDFromContext
(
c
*
gin
.
Context
)
int64
{
if
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
);
ok
&&
subject
.
UserID
>
0
{
return
subject
.
UserID
}
if
id
,
ok
:=
c
.
Get
(
"user_id"
);
ok
{
switch
v
:=
id
.
(
type
)
{
case
int64
:
return
v
case
float64
:
return
int64
(
v
)
case
string
:
n
,
_
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
return
n
}
}
// 尝试从 JWT claims 获取
if
id
,
ok
:=
c
.
Get
(
"userID"
);
ok
{
if
v
,
ok
:=
id
.
(
int64
);
ok
{
return
v
}
}
return
0
}
func
groupIDForLog
(
groupID
*
int64
)
int64
{
if
groupID
==
nil
{
return
0
}
return
*
groupID
}
func
trimForLog
(
raw
string
,
maxLen
int
)
string
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
maxLen
<=
0
||
len
(
trimmed
)
<=
maxLen
{
return
trimmed
}
return
trimmed
[
:
maxLen
]
+
"...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func
(
h
*
SoraClientHandler
)
GetModels
(
c
*
gin
.
Context
)
{
families
:=
h
.
getModelFamilies
(
c
.
Request
.
Context
())
response
.
Success
(
c
,
families
)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func
(
h
*
SoraClientHandler
)
getModelFamilies
(
ctx
context
.
Context
)
[]
service
.
SoraModelFamily
{
// 读锁检查缓存
h
.
modelCacheMu
.
RLock
()
ttl
:=
modelCacheTTL
if
!
h
.
modelCacheUpstream
{
ttl
=
modelCacheFailedTTL
}
if
h
.
cachedFamilies
!=
nil
&&
time
.
Since
(
h
.
modelCacheTime
)
<
ttl
{
families
:=
h
.
cachedFamilies
h
.
modelCacheMu
.
RUnlock
()
return
families
}
h
.
modelCacheMu
.
RUnlock
()
// 写锁更新缓存
h
.
modelCacheMu
.
Lock
()
defer
h
.
modelCacheMu
.
Unlock
()
// double-check
ttl
=
modelCacheTTL
if
!
h
.
modelCacheUpstream
{
ttl
=
modelCacheFailedTTL
}
if
h
.
cachedFamilies
!=
nil
&&
time
.
Since
(
h
.
modelCacheTime
)
<
ttl
{
return
h
.
cachedFamilies
}
// 尝试从上游获取
families
,
err
:=
h
.
fetchUpstreamModels
(
ctx
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 上游模型获取失败,使用本地配置: %v"
,
err
)
families
=
service
.
BuildSoraModelFamilies
()
h
.
cachedFamilies
=
families
h
.
modelCacheTime
=
time
.
Now
()
h
.
modelCacheUpstream
=
false
return
families
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 从上游同步到 %d 个模型家族"
,
len
(
families
))
h
.
cachedFamilies
=
families
h
.
modelCacheTime
=
time
.
Now
()
h
.
modelCacheUpstream
=
true
return
families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func
(
h
*
SoraClientHandler
)
fetchUpstreamModels
(
ctx
context
.
Context
)
([]
service
.
SoraModelFamily
,
error
)
{
if
h
.
gatewayService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gatewayService 未初始化"
)
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformSora
)
// 选择一个 Sora 账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
ctx
,
nil
,
""
,
"sora2-landscape-10s"
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"选择 Sora 账号失败: %w"
,
err
)
}
// 仅支持 API Key 类型账号
if
account
.
Type
!=
service
.
AccountTypeAPIKey
{
return
nil
,
fmt
.
Errorf
(
"当前账号类型 %s 不支持模型同步"
,
account
.
Type
)
}
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"账号缺少 api_key"
)
}
baseURL
:=
account
.
GetBaseURL
()
if
baseURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"账号缺少 base_url"
)
}
// 构建上游模型列表请求
modelsURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
"/sora/v1/models"
reqCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
reqCtx
,
http
.
MethodGet
,
modelsURL
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
client
:=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求上游失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"上游返回状态码 %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
1
*
1024
*
1024
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
// 解析 OpenAI 格式的模型列表
var
modelsResp
struct
{
Data
[]
struct
{
ID
string
`json:"id"`
}
`json:"data"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
modelsResp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"解析响应失败: %w"
,
err
)
}
if
len
(
modelsResp
.
Data
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"上游返回空模型列表"
)
}
// 提取模型 ID
modelIDs
:=
make
([]
string
,
0
,
len
(
modelsResp
.
Data
))
for
_
,
m
:=
range
modelsResp
.
Data
{
modelIDs
=
append
(
modelIDs
,
m
.
ID
)
}
// 转换为模型家族
families
:=
service
.
BuildSoraModelFamiliesFromIDs
(
modelIDs
)
if
len
(
families
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"未能从上游模型列表中识别出有效的模型家族"
)
}
return
families
,
nil
}
backend/internal/handler/sora_client_handler_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
handler
import
(
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
init
()
{
gin
.
SetMode
(
gin
.
TestMode
)
}
// ==================== Stub: SoraGenerationRepository ====================
var
_
service
.
SoraGenerationRepository
=
(
*
stubSoraGenRepo
)(
nil
)
type
stubSoraGenRepo
struct
{
gens
map
[
int64
]
*
service
.
SoraGeneration
nextID
int64
createErr
error
getErr
error
updateErr
error
deleteErr
error
listErr
error
countErr
error
countValue
int64
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
updateCallCount
*
int32
updateFailAfterN
int32
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
getByIDCallCount
int32
getByIDOverrideAfterN
int32
// 0 = 不覆盖
getByIDOverrideStatus
string
}
func
newStubSoraGenRepo
()
*
stubSoraGenRepo
{
return
&
stubSoraGenRepo
{
gens
:
make
(
map
[
int64
]
*
service
.
SoraGeneration
),
nextID
:
1
}
}
func
(
r
*
stubSoraGenRepo
)
Create
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
if
r
.
createErr
!=
nil
{
return
r
.
createErr
}
gen
.
ID
=
r
.
nextID
r
.
nextID
++
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubSoraGenRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
SoraGeneration
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
gen
,
ok
:=
r
.
gens
[
id
]
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"not found"
)
}
// 条件性状态覆盖:模拟外部取消等场景
if
r
.
getByIDOverrideAfterN
>
0
{
n
:=
atomic
.
AddInt32
(
&
r
.
getByIDCallCount
,
1
)
if
n
>
r
.
getByIDOverrideAfterN
{
cp
:=
*
gen
cp
.
Status
=
r
.
getByIDOverrideStatus
return
&
cp
,
nil
}
}
return
gen
,
nil
}
func
(
r
*
stubSoraGenRepo
)
Update
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
// 条件性失败:前 N 次成功,之后失败
if
r
.
updateCallCount
!=
nil
{
n
:=
atomic
.
AddInt32
(
r
.
updateCallCount
,
1
)
if
n
>
r
.
updateFailAfterN
{
return
fmt
.
Errorf
(
"conditional update error (call #%d)"
,
n
)
}
}
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubSoraGenRepo
)
Delete
(
_
context
.
Context
,
id
int64
)
error
{
if
r
.
deleteErr
!=
nil
{
return
r
.
deleteErr
}
delete
(
r
.
gens
,
id
)
return
nil
}
func
(
r
*
stubSoraGenRepo
)
List
(
_
context
.
Context
,
params
service
.
SoraGenerationListParams
)
([]
*
service
.
SoraGeneration
,
int64
,
error
)
{
if
r
.
listErr
!=
nil
{
return
nil
,
0
,
r
.
listErr
}
var
result
[]
*
service
.
SoraGeneration
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
!=
params
.
UserID
{
continue
}
result
=
append
(
result
,
gen
)
}
return
result
,
int64
(
len
(
result
)),
nil
}
func
(
r
*
stubSoraGenRepo
)
CountByUserAndStatus
(
_
context
.
Context
,
_
int64
,
_
[]
string
)
(
int64
,
error
)
{
if
r
.
countErr
!=
nil
{
return
0
,
r
.
countErr
}
return
r
.
countValue
,
nil
}
// ==================== 辅助函数 ====================
func
newTestSoraClientHandler
(
repo
*
stubSoraGenRepo
)
*
SoraClientHandler
{
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
return
&
SoraClientHandler
{
genService
:
genService
}
}
func
makeGinContext
(
method
,
path
,
body
string
,
userID
int64
)
(
*
gin
.
Context
,
*
httptest
.
ResponseRecorder
)
{
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
if
body
!=
""
{
c
.
Request
=
httptest
.
NewRequest
(
method
,
path
,
strings
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
}
else
{
c
.
Request
=
httptest
.
NewRequest
(
method
,
path
,
nil
)
}
if
userID
>
0
{
c
.
Set
(
"user_id"
,
userID
)
}
return
c
,
rec
}
func
parseResponse
(
t
*
testing
.
T
,
rec
*
httptest
.
ResponseRecorder
)
map
[
string
]
any
{
t
.
Helper
()
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
return
resp
}
// ==================== 纯函数测试: buildAsyncRequestBody ====================
func
TestBuildAsyncRequestBody
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"sora2-landscape-10s"
,
"一只猫在跳舞"
,
""
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
parsed
[
"model"
])
require
.
Equal
(
t
,
false
,
parsed
[
"stream"
])
msgs
:=
parsed
[
"messages"
]
.
([]
any
)
require
.
Len
(
t
,
msgs
,
1
)
msg
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"user"
,
msg
[
"role"
])
require
.
Equal
(
t
,
"一只猫在跳舞"
,
msg
[
"content"
])
}
func
TestBuildAsyncRequestBody_EmptyPrompt
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"gpt-image"
,
""
,
""
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"gpt-image"
,
parsed
[
"model"
])
msgs
:=
parsed
[
"messages"
]
.
([]
any
)
msg
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
""
,
msg
[
"content"
])
}
func
TestBuildAsyncRequestBody_WithImageInput
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"gpt-image"
,
"一只猫"
,
"https://example.com/ref.png"
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"https://example.com/ref.png"
,
parsed
[
"image_input"
])
}
func
TestBuildAsyncRequestBody_WithVideoCount
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"sora2-landscape-10s"
,
"一只猫在跳舞"
,
""
,
3
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
float64
(
3
),
parsed
[
"video_count"
])
}
func
TestNormalizeVideoCount
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
normalizeVideoCount
(
"video"
,
0
))
require
.
Equal
(
t
,
2
,
normalizeVideoCount
(
"video"
,
2
))
require
.
Equal
(
t
,
3
,
normalizeVideoCount
(
"video"
,
5
))
require
.
Equal
(
t
,
1
,
normalizeVideoCount
(
"image"
,
3
))
}
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
func
TestParseMediaURLsFromBody_MediaURLs
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_SingleMediaURL
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_url":"https://a.com/video.mp4"}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/video.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_EmptyBody
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
(
nil
))
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
{}))
}
func
TestParseMediaURLsFromBody_InvalidJSON
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
"not json"
)))
}
func
TestParseMediaURLsFromBody_NoMediaFields
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"data":"something"}`
)))
}
func
TestParseMediaURLsFromBody_EmptyMediaURL
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_url":""}`
)))
}
func
TestParseMediaURLsFromBody_EmptyMediaURLs
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":[]}`
)))
}
func
TestParseMediaURLsFromBody_MediaURLsPriority
(
t
*
testing
.
T
)
{
body
:=
`{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
urls
:=
parseMediaURLsFromBody
([]
byte
(
body
))
require
.
Len
(
t
,
urls
,
2
)
require
.
Equal
(
t
,
"https://multi.com/a.mp4"
,
urls
[
0
])
}
func
TestParseMediaURLsFromBody_FilterEmpty
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_AllEmpty
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["",""]}`
)))
}
func
TestParseMediaURLsFromBody_NonStringArray
(
t
*
testing
.
T
)
{
// media_urls 不是 string 数组
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":"not-array"}`
)))
}
func
TestParseMediaURLsFromBody_MediaURLNotString
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_url":123}`
)))
}
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
func
TestExtractMediaURLsFromResult_OAuthPath
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
"https://oauth.com/video.mp4"
}
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Equal
(
t
,
"https://oauth.com/video.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://oauth.com/video.mp4"
},
urls
)
}
func
TestExtractMediaURLsFromResult_OAuthWithBody
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
"https://oauth.com/video.mp4"
}
recorder
:=
httptest
.
NewRecorder
()
_
,
_
=
recorder
.
Write
([]
byte
(
`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`
))
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Equal
(
t
,
"https://body.com/1.mp4"
,
url
)
require
.
Len
(
t
,
urls
,
2
)
}
func
TestExtractMediaURLsFromResult_APIKeyPath
(
t
*
testing
.
T
)
{
recorder
:=
httptest
.
NewRecorder
()
_
,
_
=
recorder
.
Write
([]
byte
(
`{"media_url":"https://upstream.com/video.mp4"}`
))
url
,
urls
:=
extractMediaURLsFromResult
(
nil
,
recorder
)
require
.
Equal
(
t
,
"https://upstream.com/video.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://upstream.com/video.mp4"
},
urls
)
}
func
TestExtractMediaURLsFromResult_NilResultEmptyBody
(
t
*
testing
.
T
)
{
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
nil
,
recorder
)
require
.
Empty
(
t
,
url
)
require
.
Nil
(
t
,
urls
)
}
func
TestExtractMediaURLsFromResult_EmptyMediaURL
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
""
}
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Empty
(
t
,
url
)
require
.
Nil
(
t
,
urls
)
}
// ==================== getUserIDFromContext ====================
func
TestGetUserIDFromContext_Int64
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
int64
(
42
))
require
.
Equal
(
t
,
int64
(
42
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_AuthSubject
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
777
})
require
.
Equal
(
t
,
int64
(
777
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_Float64
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
float64
(
99
))
require
.
Equal
(
t
,
int64
(
99
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_String
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
"123"
)
require
.
Equal
(
t
,
int64
(
123
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_UserIDFallback
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"userID"
,
int64
(
55
))
require
.
Equal
(
t
,
int64
(
55
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_NoID
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
require
.
Equal
(
t
,
int64
(
0
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_InvalidString
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
"not-a-number"
)
require
.
Equal
(
t
,
int64
(
0
),
getUserIDFromContext
(
c
))
}
// ==================== Handler: Generate ====================
func
TestGenerate_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
0
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_MissingModel
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_MissingPrompt
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_InvalidJSON
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{invalid`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_TooManyRequests
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
countValue
=
3
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
func
TestGenerate_CountError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
countErr
=
fmt
.
Errorf
(
"db error"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestGenerate_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"测试生成"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
NotZero
(
t
,
data
[
"generation_id"
])
require
.
Equal
(
t
,
"pending"
,
data
[
"status"
])
}
func
TestGenerate_DefaultMediaType
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"video"
,
repo
.
gens
[
1
]
.
MediaType
)
}
func
TestGenerate_ImageMediaType
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"gpt-image","prompt":"test","media_type":"image"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"image"
,
repo
.
gens
[
1
]
.
MediaType
)
}
func
TestGenerate_CreatePendingError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
createErr
=
fmt
.
Errorf
(
"create failed"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestGenerate_NilQuotaServiceSkipsCheck
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestGenerate_APIKeyInContext
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
42
))
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_NoAPIKeyInContext
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_ConcurrencyBoundary
(
t
*
testing
.
T
)
{
// activeCount == 2 应该允许
repo
:=
newStubSoraGenRepo
()
repo
.
countValue
=
2
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== Handler: ListGenerations ====================
func
TestListGenerations_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
0
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestListGenerations_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Model
:
"sora2-landscape-10s"
,
Status
:
"completed"
,
StorageType
:
"upstream"
}
repo
.
gens
[
2
]
=
&
service
.
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Model
:
"gpt-image"
,
Status
:
"pending"
,
StorageType
:
"none"
}
repo
.
nextID
=
3
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations?page=1&page_size=10"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
items
:=
data
[
"data"
]
.
([]
any
)
require
.
Len
(
t
,
items
,
2
)
require
.
Equal
(
t
,
float64
(
2
),
data
[
"total"
])
}
func
TestListGenerations_ListError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
listErr
=
fmt
.
Errorf
(
"db error"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestListGenerations_DefaultPagination
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
// 不传分页参数,应默认 page=1 page_size=20
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
float64
(
1
),
data
[
"page"
])
}
// ==================== Handler: GetGeneration ====================
func
TestGetGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGetGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/abc"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGetGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/999"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestGetGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestGetGeneration_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Model
:
"sora2-landscape-10s"
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
float64
(
1
),
data
[
"id"
])
}
// ==================== Handler: DeleteGeneration ====================
func
TestDeleteGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestDeleteGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/abc"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestDeleteGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/999"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestDeleteGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestDeleteGeneration_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
// ==================== Handler: CancelGeneration ====================
func
TestCancelGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestCancelGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/abc/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestCancelGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/999/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestCancelGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"pending"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestCancelGeneration_Pending
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestCancelGeneration_Generating
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"generating"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestCancelGeneration_Completed
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
func
TestCancelGeneration_Failed
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"failed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
func
TestCancelGeneration_Cancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"cancelled"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
// ==================== Handler: GetQuota ====================
func
TestGetQuota_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
0
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGetQuota_NilQuotaService
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
1
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"unlimited"
,
data
[
"source"
])
}
// ==================== Handler: GetModels ====================
func
TestGetModels
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/models"
,
""
,
0
)
h
.
GetModels
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
([]
any
)
require
.
Len
(
t
,
data
,
4
)
// 验证类型分布
videoCount
,
imageCount
:=
0
,
0
for
_
,
item
:=
range
data
{
m
:=
item
.
(
map
[
string
]
any
)
if
m
[
"type"
]
==
"video"
{
videoCount
++
}
else
if
m
[
"type"
]
==
"image"
{
imageCount
++
}
}
require
.
Equal
(
t
,
3
,
videoCount
)
require
.
Equal
(
t
,
1
,
imageCount
)
}
// ==================== Handler: GetStorageStatus ====================
func
TestGetStorageStatus_NilS3
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
false
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
require
.
Equal
(
t
,
false
,
data
[
"local_enabled"
])
}
func
TestGetStorageStatus_LocalEnabled
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-storage-status-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
false
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
require
.
Equal
(
t
,
true
,
data
[
"local_enabled"
])
}
// ==================== Handler: SaveToStorage ====================
func
TestSaveToStorage_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestSaveToStorage_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/abc/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/999/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestSaveToStorage_NotUpstream
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"s3"
,
MediaURL
:
"https://example.com/v.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_EmptyMediaURL
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
""
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_S3Nil
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"云存储"
)
}
func
TestSaveToStorage_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
func
TestStoreMediaWithDegradation_NilS3NilMedia
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
urls
,
storageType
,
keys
,
size
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://upstream.com/v.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://upstream.com/v.mp4"
},
urls
)
require
.
Nil
(
t
,
keys
)
require
.
Equal
(
t
,
int64
(
0
),
size
)
}
func
TestStoreMediaWithDegradation_NilGuardsMultiURL
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
urls
,
storageType
,
keys
,
size
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://a.com/1.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
require
.
Nil
(
t
,
keys
)
require
.
Equal
(
t
,
int64
(
0
),
size
)
}
func
TestStoreMediaWithDegradation_EmptyMediaURLsFallback
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
[]
string
{},
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://upstream.com/v.mp4"
,
url
)
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var
_
service
.
UserRepository
=
(
*
stubUserRepoForHandler
)(
nil
)
type
stubUserRepoForHandler
struct
{
users
map
[
int64
]
*
service
.
User
updateErr
error
}
func
newStubUserRepoForHandler
()
*
stubUserRepoForHandler
{
return
&
stubUserRepoForHandler
{
users
:
make
(
map
[
int64
]
*
service
.
User
)}
}
func
(
r
*
stubUserRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
u
,
ok
:=
r
.
users
[
id
];
ok
{
return
u
,
nil
}
return
nil
,
fmt
.
Errorf
(
"user not found"
)
}
func
(
r
*
stubUserRepoForHandler
)
Update
(
_
context
.
Context
,
user
*
service
.
User
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
users
[
user
.
ID
]
=
user
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
User
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
GetByEmail
(
context
.
Context
,
string
)
(
*
service
.
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
GetFirstAdmin
(
context
.
Context
)
(
*
service
.
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// ==================== NewSoraClientHandler ====================
func
TestNewSoraClientHandler
(
t
*
testing
.
T
)
{
h
:=
NewSoraClientHandler
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
require
.
NotNil
(
t
,
h
)
}
func
TestNewSoraClientHandler_WithAPIKeyService
(
t
*
testing
.
T
)
{
h
:=
NewSoraClientHandler
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
require
.
NotNil
(
t
,
h
)
require
.
Nil
(
t
,
h
.
apiKeyService
)
}
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
var
_
service
.
APIKeyRepository
=
(
*
stubAPIKeyRepoForHandler
)(
nil
)
type
stubAPIKeyRepoForHandler
struct
{
keys
map
[
int64
]
*
service
.
APIKey
getErr
error
}
func
newStubAPIKeyRepoForHandler
()
*
stubAPIKeyRepoForHandler
{
return
&
stubAPIKeyRepoForHandler
{
keys
:
make
(
map
[
int64
]
*
service
.
APIKey
)}
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
APIKey
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
if
k
,
ok
:=
r
.
keys
[
id
];
ok
{
return
k
,
nil
}
return
nil
,
fmt
.
Errorf
(
"api key not found: %d"
,
id
)
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
APIKey
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetKeyAndOwnerID
(
_
context
.
Context
,
_
int64
)
(
string
,
int64
,
error
)
{
return
""
,
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByKey
(
context
.
Context
,
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByKeyForAuth
(
context
.
Context
,
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Update
(
context
.
Context
,
*
service
.
APIKey
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListByUserID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
,
_
service
.
APIKeyListFilters
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
VerifyOwnership
(
context
.
Context
,
int64
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
CountByUserID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ExistsByKey
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListByGroupID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
SearchAPIKeys
(
context
.
Context
,
int64
,
string
,
int
)
([]
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
UpdateGroupIDByUserAndGroup
(
_
context
.
Context
,
userID
,
oldGroupID
,
newGroupID
int64
)
(
int64
,
error
)
{
var
updated
int64
for
id
,
key
:=
range
r
.
keys
{
if
key
.
UserID
!=
userID
||
key
.
GroupID
==
nil
||
*
key
.
GroupID
!=
oldGroupID
{
continue
}
clone
:=
*
key
gid
:=
newGroupID
clone
.
GroupID
=
&
gid
r
.
keys
[
id
]
=
&
clone
updated
++
}
return
updated
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListKeysByUserID
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListKeysByGroupID
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
IncrementQuotaUsed
(
_
context
.
Context
,
_
int64
,
_
float64
)
(
float64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
UpdateLastUsed
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
IncrementRateLimitUsage
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ResetRateLimitWindows
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetRateLimitData
(
context
.
Context
,
int64
)
(
*
service
.
APIKeyRateLimitData
,
error
)
{
return
nil
,
nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService
func
newTestAPIKeyService
(
repo
*
stubAPIKeyRepoForHandler
)
*
service
.
APIKeyService
{
return
service
.
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
}
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
func
TestGenerate_WithAPIKeyID_Success
(
t
*
testing
.
T
)
{
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
groupID
:=
int64
(
5
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
&
groupID
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
NotZero
(
t
,
data
[
"generation_id"
])
// 验证 api_key_id 已关联到生成记录
gen
:=
repo
.
gens
[
1
]
require
.
NotNil
(
t
,
gen
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
gen
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyID_NotFound
(
t
*
testing
.
T
)
{
// 前端传递不存在的 api_key_id → 400
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不存在"
)
}
func
TestGenerate_WithAPIKeyID_WrongUser
(
t
*
testing
.
T
)
{
// 前端传递别人的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
999
,
// 属于 user 999
Status
:
service
.
StatusAPIKeyActive
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不属于"
)
}
func
TestGenerate_WithAPIKeyID_Disabled
(
t
*
testing
.
T
)
{
// 前端传递已禁用的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyDisabled
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不可用"
)
}
func
TestGenerate_WithAPIKeyID_QuotaExhausted
(
t
*
testing
.
T
)
{
// 前端传递配额耗尽的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyQuotaExhausted
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
func
TestGenerate_WithAPIKeyID_Expired
(
t
*
testing
.
T
)
{
// 前端传递已过期的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyExpired
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
func
TestGenerate_WithAPIKeyID_NilAPIKeyService
(
t
*
testing
.
T
)
{
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// apiKeyService = nil
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyID_NilGroupID
(
t
*
testing
.
T
)
{
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
nil
,
// 无分组
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_NoAPIKeyID_NoContext_NilResult
(
t
*
testing
.
T
)
{
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyIDInBody_OverridesContext
(
t
*
testing
.
T
)
{
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
groupID
:=
int64
(
10
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
&
groupID
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
99
))
// context 中有另一个 api_key_id
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithContextAPIKeyID_FallbackPath
(
t
*
testing
.
T
)
{
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
99
))
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 应使用 context 中的 api_key_id=99
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
99
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_APIKeyID_Zero_IgnoredInJSON
(
t
*
testing
.
T
)
{
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
// api_key_id=0 不存在 → 400
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
func
TestProcessGeneration_WithGroupID_NoForcePlatform
(
t
*
testing
.
T
)
{
// groupID 不为 nil → 不设置 ForcePlatform
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
gid
:=
int64
(
5
)
h
.
processGeneration
(
1
,
1
,
&
gid
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
func
TestProcessGeneration_NilGroupID_SetsForcePlatform
(
t
*
testing
.
T
)
{
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
func
TestProcessGeneration_MarkGeneratingStateConflict
(
t
*
testing
.
T
)
{
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"cancelled"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
// ==================== GenerateRequest JSON 解析 ====================
func
TestGenerateRequest_WithAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// 验证 api_key_id 在 JSON 中正确解析为 *int64
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test","api_key_id":42}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
req
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
req
.
APIKeyID
)
}
func
TestGenerateRequest_WithoutAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// 不传 api_key_id → 解析后为 nil
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test"}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
req
.
APIKeyID
)
}
func
TestGenerateRequest_NullAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// api_key_id: null → 解析后为 nil
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test","api_key_id":null}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
req
.
APIKeyID
)
}
func
TestGenerateRequest_FullFields_JSONParsing
(
t
*
testing
.
T
)
{
// 全字段解析
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{
"model":"sora2-landscape-10s",
"prompt":"test prompt",
"media_type":"video",
"video_count":2,
"image_input":"data:image/png;base64,abc",
"api_key_id":100
}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
req
.
Model
)
require
.
Equal
(
t
,
"test prompt"
,
req
.
Prompt
)
require
.
Equal
(
t
,
"video"
,
req
.
MediaType
)
require
.
Equal
(
t
,
2
,
req
.
VideoCount
)
require
.
Equal
(
t
,
"data:image/png;base64,abc"
,
req
.
ImageInput
)
require
.
NotNil
(
t
,
req
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
100
),
*
req
.
APIKeyID
)
}
func
TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID
(
t
*
testing
.
T
)
{
// api_key_id 为 nil 时 JSON 序列化应省略
req
:=
GenerateRequest
{
Model
:
"sora2"
,
Prompt
:
"test"
}
b
,
err
:=
json
.
Marshal
(
req
)
require
.
NoError
(
t
,
err
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
b
,
&
parsed
))
_
,
hasAPIKeyID
:=
parsed
[
"api_key_id"
]
require
.
False
(
t
,
hasAPIKeyID
,
"api_key_id 为 nil 时应省略"
)
}
func
TestGenerateRequest_JSONSerialize_IncludesAPIKeyID
(
t
*
testing
.
T
)
{
// api_key_id 不为 nil 时 JSON 序列化应包含
id
:=
int64
(
42
)
req
:=
GenerateRequest
{
Model
:
"sora2"
,
Prompt
:
"test"
,
APIKeyID
:
&
id
}
b
,
err
:=
json
.
Marshal
(
req
)
require
.
NoError
(
t
,
err
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
b
,
&
parsed
))
require
.
Equal
(
t
,
float64
(
42
),
parsed
[
"api_key_id"
])
}
// ==================== GetQuota: 有配额服务 ====================
func
TestGetQuota_WithQuotaService_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
1
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"user"
,
data
[
"source"
])
require
.
Equal
(
t
,
float64
(
10
*
1024
*
1024
),
data
[
"quota_bytes"
])
require
.
Equal
(
t
,
float64
(
3
*
1024
*
1024
),
data
[
"used_bytes"
])
}
func
TestGetQuota_WithQuotaService_Error
(
t
*
testing
.
T
)
{
// 用户不存在时 GetQuota 返回错误
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
999
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== Generate: 配额检查 ====================
func
TestGenerate_QuotaCheckFailed
(
t
*
testing
.
T
)
{
// 配额超限时返回 429
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
1024
,
SoraStorageUsedBytes
:
1025
,
// 已超限
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
func
TestGenerate_QuotaCheckPassed
(
t
*
testing
.
T
)
{
// 配额充足时允许生成
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
var
_
service
.
SettingRepository
=
(
*
stubSettingRepoForHandler
)(
nil
)
type
stubSettingRepoForHandler
struct
{
values
map
[
string
]
string
}
func
newStubSettingRepoForHandler
(
values
map
[
string
]
string
)
*
stubSettingRepoForHandler
{
if
values
==
nil
{
values
=
make
(
map
[
string
]
string
)
}
return
&
stubSettingRepoForHandler
{
values
:
values
}
}
func
(
r
*
stubSettingRepoForHandler
)
Get
(
_
context
.
Context
,
key
string
)
(
*
service
.
Setting
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
&
service
.
Setting
{
Key
:
key
,
Value
:
v
},
nil
}
return
nil
,
service
.
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForHandler
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
v
,
nil
}
return
""
,
service
.
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForHandler
)
Set
(
_
context
.
Context
,
key
,
value
string
)
error
{
r
.
values
[
key
]
=
value
return
nil
}
func
(
r
*
stubSettingRepoForHandler
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
result
:=
make
(
map
[
string
]
string
)
for
_
,
k
:=
range
keys
{
if
v
,
ok
:=
r
.
values
[
k
];
ok
{
result
[
k
]
=
v
}
}
return
result
,
nil
}
func
(
r
*
stubSettingRepoForHandler
)
SetMultiple
(
_
context
.
Context
,
settings
map
[
string
]
string
)
error
{
for
k
,
v
:=
range
settings
{
r
.
values
[
k
]
=
v
}
return
nil
}
func
(
r
*
stubSettingRepoForHandler
)
GetAll
(
_
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
r
.
values
,
nil
}
func
(
r
*
stubSettingRepoForHandler
)
Delete
(
_
context
.
Context
,
key
string
)
error
{
delete
(
r
.
values
,
key
)
return
nil
}
// ==================== S3 / MediaStorage 辅助函数 ====================
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
func
newS3StorageForHandler
(
endpoint
string
)
*
service
.
SoraS3Storage
{
settingRepo
:=
newStubSettingRepoForHandler
(
map
[
string
]
string
{
"sora_s3_enabled"
:
"true"
,
"sora_s3_endpoint"
:
endpoint
,
"sora_s3_region"
:
"us-east-1"
,
"sora_s3_bucket"
:
"test-bucket"
,
"sora_s3_access_key_id"
:
"AKIATEST"
,
"sora_s3_secret_access_key"
:
"test-secret"
,
"sora_s3_prefix"
:
"sora"
,
"sora_s3_force_path_style"
:
"true"
,
})
settingService
:=
service
.
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
return
service
.
NewSoraS3Storage
(
settingService
)
}
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
func
newFakeSourceServer
()
*
httptest
.
Server
{
return
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"video/mp4"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"fake video data for test"
))
}))
}
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
func
newFakeS3Server
(
mode
string
)
*
httptest
.
Server
{
var
counter
atomic
.
Int32
return
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
Copy
(
io
.
Discard
,
r
.
Body
)
_
=
r
.
Body
.
Close
()
switch
mode
{
case
"ok"
:
w
.
Header
()
.
Set
(
"ETag"
,
`"test-etag"`
)
w
.
WriteHeader
(
http
.
StatusOK
)
case
"fail"
:
w
.
WriteHeader
(
http
.
StatusForbidden
)
_
,
_
=
w
.
Write
([]
byte
(
`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`
))
case
"fail-second"
:
n
:=
counter
.
Add
(
1
)
if
n
<=
1
{
w
.
Header
()
.
Set
(
"ETag"
,
`"test-etag"`
)
w
.
WriteHeader
(
http
.
StatusOK
)
}
else
{
w
.
WriteHeader
(
http
.
StatusForbidden
)
_
,
_
=
w
.
Write
([]
byte
(
`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`
))
}
}
}))
}
// ==================== processGeneration 直接调用测试 ====================
func
TestProcessGeneration_MarkGeneratingFails
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Empty
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
)
}
func
TestProcessGeneration_GatewayServiceNil
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// gatewayService 未设置 → MarkFailed
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
// ==================== storeMediaWithDegradation: S3 路径 ====================
func
TestStoreMediaWithDegradation_S3SuccessSingleURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
storageType
)
require
.
Len
(
t
,
s3Keys
,
1
)
require
.
NotEmpty
(
t
,
s3Keys
[
0
])
require
.
Len
(
t
,
storedURLs
,
1
)
require
.
Equal
(
t
,
storedURL
,
storedURLs
[
0
])
require
.
Contains
(
t
,
storedURL
,
fakeS3
.
URL
)
require
.
Contains
(
t
,
storedURL
,
"/test-bucket/"
)
require
.
Greater
(
t
,
fileSize
,
int64
(
0
))
}
func
TestStoreMediaWithDegradation_S3SuccessMultiURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
urls
:=
[]
string
{
sourceServer
.
URL
+
"/a.mp4"
,
sourceServer
.
URL
+
"/b.mp4"
}
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/a.mp4"
,
urls
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
storageType
)
require
.
Len
(
t
,
s3Keys
,
2
)
require
.
Len
(
t
,
storedURLs
,
2
)
require
.
Equal
(
t
,
storedURL
,
storedURLs
[
0
])
require
.
Contains
(
t
,
storedURLs
[
0
],
fakeS3
.
URL
)
require
.
Contains
(
t
,
storedURLs
[
1
],
fakeS3
.
URL
)
require
.
Greater
(
t
,
fileSize
,
int64
(
0
))
}
func
TestStoreMediaWithDegradation_S3DownloadFails
(
t
*
testing
.
T
)
{
// 上游返回 404 → 下载失败 → S3 上传不会开始
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
badSource
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusNotFound
)
}))
defer
badSource
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
badSource
.
URL
+
"/missing.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
}
func
TestStoreMediaWithDegradation_S3FailsSingleURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
// S3 失败,降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
}
func
TestStoreMediaWithDegradation_S3PartialFailureCleanup
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail-second"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
urls
:=
[]
string
{
sourceServer
.
URL
+
"/a.mp4"
,
sourceServer
.
URL
+
"/b.mp4"
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/a.mp4"
,
urls
,
)
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
}
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
func
TestStoreMediaWithDegradation_LocalStorageFails
(
t
*
testing
.
T
)
{
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
"/dev/null/invalid_dir"
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
nil
,
)
// 本地存储失败,降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
}
func
TestStoreMediaWithDegradation_LocalStorageSuccess
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-handler-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
DownloadTimeoutSeconds
:
5
,
MaxDownloadBytes
:
10
*
1024
*
1024
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeLocal
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
// 本地存储不返回 S3 keys
}
func
TestStoreMediaWithDegradation_S3FailsFallbackToLocal
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-handler-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
DownloadTimeoutSeconds
:
5
,
MaxDownloadBytes
:
10
*
1024
*
1024
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
,
mediaStorage
:
mediaStorage
,
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
// S3 失败 → 本地存储成功
require
.
Equal
(
t
,
service
.
SoraStorageTypeLocal
,
storageType
)
}
// ==================== SaveToStorage: S3 路径 ====================
func
TestSaveToStorage_S3EnabledButUploadFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"S3"
)
}
func
TestSaveToStorage_UpstreamURLExpired
(
t
*
testing
.
T
)
{
expiredServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusForbidden
)
}))
defer
expiredServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
expiredServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusGone
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"过期"
)
}
func
TestSaveToStorage_S3EnabledUploadSuccess
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Contains
(
t
,
data
[
"message"
],
"S3"
)
require
.
NotEmpty
(
t
,
data
[
"object_key"
])
// 验证记录已更新为 S3 存储
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
}
func
TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v1.mp4"
,
MediaURLs
:
[]
string
{
sourceServer
.
URL
+
"/v1.mp4"
,
sourceServer
.
URL
+
"/v2.mp4"
,
},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Len
(
t
,
data
[
"object_keys"
]
.
([]
any
),
2
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
Len
(
t
,
repo
.
gens
[
1
]
.
S3ObjectKeys
,
2
)
require
.
Len
(
t
,
repo
.
gens
[
1
]
.
MediaURLs
,
2
)
}
func
TestSaveToStorage_S3EnabledUploadSuccessWithQuota
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 验证配额已累加
require
.
Greater
(
t
,
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
,
int64
(
0
))
}
func
TestSaveToStorage_S3UploadSuccessMarkCompletedFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== GetStorageStatus: S3 路径 ====================
func
TestGetStorageStatus_S3EnabledNotHealthy
(
t
*
testing
.
T
)
{
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
true
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
}
func
TestGetStorageStatus_S3EnabledHealthy
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
true
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
true
,
data
[
"s3_healthy"
])
}
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
var
_
service
.
AccountRepository
=
(
*
stubAccountRepoForHandler
)(
nil
)
type
stubAccountRepoForHandler
struct
{
accounts
[]
service
.
Account
}
func
(
r
*
stubAccountRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
for
i
:=
range
r
.
accounts
{
if
r
.
accounts
[
i
]
.
ID
==
id
{
return
&
r
.
accounts
[
i
],
nil
}
}
return
nil
,
fmt
.
Errorf
(
"account not found"
)
}
func
(
r
*
stubAccountRepoForHandler
)
GetByIDs
(
context
.
Context
,
[]
int64
)
([]
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ExistsByID
(
context
.
Context
,
int64
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
GetByCRSAccountID
(
context
.
Context
,
string
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
FindByExtraField
(
context
.
Context
,
string
,
any
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListCRSAccountIDs
(
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
Update
(
context
.
Context
,
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
string
,
int64
,
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListByGroup
(
context
.
Context
,
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListActive
(
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListByPlatform
(
context
.
Context
,
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateLastUsed
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BatchUpdateLastUsed
(
context
.
Context
,
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetError
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearError
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetSchedulable
(
context
.
Context
,
int64
,
bool
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
AutoPauseExpiredAccounts
(
context
.
Context
,
time
.
Time
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BindGroups
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulable
(
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupID
(
context
.
Context
,
int64
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByPlatform
(
_
context
.
Context
,
_
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupIDAndPlatform
(
context
.
Context
,
int64
,
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByPlatforms
(
context
.
Context
,
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupIDAndPlatforms
(
context
.
Context
,
int64
,
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableUngroupedByPlatform
(
_
context
.
Context
,
_
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableUngroupedByPlatforms
(
_
context
.
Context
,
_
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetRateLimited
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetModelRateLimit
(
context
.
Context
,
int64
,
string
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetOverloaded
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetTempUnschedulable
(
context
.
Context
,
int64
,
time
.
Time
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearTempUnschedulable
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearRateLimit
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearAntigravityQuotaScopes
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearModelRateLimits
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateSessionWindow
(
context
.
Context
,
int64
,
*
time
.
Time
,
*
time
.
Time
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateExtra
(
context
.
Context
,
int64
,
map
[
string
]
any
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BulkUpdate
(
context
.
Context
,
[]
int64
,
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
IncrementQuotaUsed
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ResetQuotaUsed
(
context
.
Context
,
int64
)
error
{
return
nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var
_
service
.
SoraClient
=
(
*
stubSoraClientForHandler
)(
nil
)
type
stubSoraClientForHandler
struct
{
videoStatus
*
service
.
SoraVideoTaskStatus
}
func
(
s
*
stubSoraClientForHandler
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClientForHandler
)
UploadImage
(
context
.
Context
,
*
service
.
Account
,
[]
byte
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateImageTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateVideoTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraVideoRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateStoryboardTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraStoryboardRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
UploadCharacterVideo
(
context
.
Context
,
*
service
.
Account
,
[]
byte
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetCameoStatus
(
context
.
Context
,
*
service
.
Account
,
string
)
(
*
service
.
SoraCameoStatus
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
DownloadCharacterImage
(
context
.
Context
,
*
service
.
Account
,
string
)
([]
byte
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
UploadCharacterImage
(
context
.
Context
,
*
service
.
Account
,
[]
byte
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
FinalizeCharacter
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
SetCharacterPublic
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
DeleteCharacter
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
PostVideoForWatermarkFree
(
context
.
Context
,
*
service
.
Account
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
DeletePost
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetWatermarkFreeURLCustom
(
context
.
Context
,
*
service
.
Account
,
string
,
string
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
EnhancePrompt
(
context
.
Context
,
*
service
.
Account
,
string
,
string
,
int
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetImageTask
(
context
.
Context
,
*
service
.
Account
,
string
)
(
*
service
.
SoraImageTaskStatus
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetVideoTask
(
_
context
.
Context
,
_
*
service
.
Account
,
_
string
)
(
*
service
.
SoraVideoTaskStatus
,
error
)
{
return
s
.
videoStatus
,
nil
}
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
func
newMinimalGatewayService
(
accountRepo
service
.
AccountRepository
)
*
service
.
GatewayService
{
return
service
.
NewGatewayService
(
accountRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
}
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
func
newMinimalSoraGatewayService
(
soraClient
service
.
SoraClient
)
*
service
.
SoraGatewayService
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
return
service
.
NewSoraGatewayService
(
soraClient
,
nil
,
nil
,
cfg
)
}
// ==================== processGeneration: 更多路径测试 ====================
func
TestProcessGeneration_SelectAccountError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"选择账号失败"
)
}
func
TestProcessGeneration_SoraGatewayServiceNil
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 提供可用账号使 SelectAccountForModel 成功
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// soraGatewayService 为 nil
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"soraGatewayService"
)
}
func
TestProcessGeneration_ForwardError
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// SoraClient 返回视频任务失败
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"content policy violation"
,
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"生成失败"
)
}
func
TestProcessGeneration_ForwardErrorCancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
repo
.
getByIDOverrideAfterN
=
1
repo
.
getByIDOverrideStatus
=
"cancelled"
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"reject"
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestProcessGeneration_ForwardSuccessNoMediaURL
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// SoraClient 返回 completed 但无 URL
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
nil
,
// 无 URL
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"未获取到媒体 URL"
)
}
func
TestProcessGeneration_ForwardSuccessCancelledBeforeStore
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
// 第 2 次返回 "cancelled" 状态,模拟外部取消
repo
.
getByIDOverrideAfterN
=
1
repo
.
getByIDOverrideStatus
=
"cancelled"
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestProcessGeneration_FullSuccessUpstream
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
// 无 S3 和本地存储,降级到 upstream
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
NotEmpty
(
t
,
repo
.
gens
[
1
]
.
MediaURL
)
}
func
TestProcessGeneration_FullSuccessWithS3
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
sourceServer
.
URL
+
"/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
NotEmpty
(
t
,
repo
.
gens
[
1
]
.
S3ObjectKeys
)
require
.
Greater
(
t
,
repo
.
gens
[
1
]
.
FileSizeBytes
,
int64
(
0
))
// 验证配额已累加
require
.
Greater
(
t
,
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
,
int64
(
0
))
}
func
TestProcessGeneration_MarkCompletedFails
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
repo
.
updateCallCount
=
new
(
int32
)
repo
.
updateFailAfterN
=
1
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
}
// ==================== cleanupStoredMedia 直接测试 ====================
func
TestCleanupStoredMedia_S3Path
(
t
*
testing
.
T
)
{
// S3 清理路径:s3Storage 为 nil 时不 panic
h
:=
&
SoraClientHandler
{}
// 不应 panic
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
},
nil
)
}
func
TestCleanupStoredMedia_LocalPath
(
t
*
testing
.
T
)
{
// 本地清理路径:mediaStorage 为 nil 时不 panic
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
[]
string
{
"/tmp/test.mp4"
})
}
func
TestCleanupStoredMedia_UpstreamPath
(
t
*
testing
.
T
)
{
// upstream 类型不清理
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeUpstream
,
nil
,
nil
)
}
func
TestCleanupStoredMedia_EmptyKeys
(
t
*
testing
.
T
)
{
// 空 keys 不触发清理
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
nil
,
nil
)
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
nil
)
}
// ==================== DeleteGeneration: 本地存储清理路径 ====================
func
TestDeleteGeneration_LocalStorageCleanup
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-delete-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"video/test.mp4"
,
MediaURLs
:
[]
string
{
"video/test.mp4"
},
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback
(
t
*
testing
.
T
)
{
// MediaURLs 为空,使用 MediaURL 作为清理路径
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-delete-fallback-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"video/test.mp4"
,
MediaURLs
:
nil
,
// 空
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestDeleteGeneration_NonLocalStorage_SkipCleanup
(
t
*
testing
.
T
)
{
// 非本地存储类型 → 跳过清理
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeUpstream
,
MediaURL
:
"https://upstream.com/v.mp4"
,
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestDeleteGeneration_DeleteError
(
t
*
testing
.
T
)
{
// repo.Delete 出错
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
}
repo
.
deleteErr
=
fmt
.
Errorf
(
"delete failed"
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
// ==================== fetchUpstreamModels 测试 ====================
func
TestFetchUpstreamModels_NilGateway
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
h
:=
&
SoraClientHandler
{}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"gatewayService 未初始化"
)
}
func
TestFetchUpstreamModels_NoAccounts
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"选择 Sora 账号失败"
)
}
func
TestFetchUpstreamModels_NonAPIKeyAccount
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
"oauth"
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"不支持模型同步"
)
}
func
TestFetchUpstreamModels_MissingAPIKey
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://sora.test"
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"api_key"
)
}
func
TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
}
func
TestFetchUpstreamModels_UpstreamReturns500
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"状态码 500"
)
}
func
TestFetchUpstreamModels_UpstreamReturnsInvalidJSON
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"not json"
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"解析响应失败"
)
}
func
TestFetchUpstreamModels_UpstreamReturnsEmptyList
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"空模型列表"
)
}
func
TestFetchUpstreamModels_Success
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// 验证请求头
require
.
Equal
(
t
,
"Bearer sk-test"
,
r
.
Header
.
Get
(
"Authorization"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
r
.
URL
.
Path
,
"/sora/v1/models"
))
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
families
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
families
)
}
func
TestFetchUpstreamModels_UnrecognizedModels
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"未能从上游模型列表中识别"
)
}
// ==================== getModelFamilies 缓存测试 ====================
func
TestGetModelFamilies_CachesLocalConfig
(
t
*
testing
.
T
)
{
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
h
:=
&
SoraClientHandler
{}
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
families2
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
Equal
(
t
,
families
,
families2
)
require
.
False
(
t
,
h
.
modelCacheUpstream
)
}
func
TestGetModelFamilies_CachesUpstreamResult
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
require
.
True
(
t
,
h
.
modelCacheUpstream
)
// 第二次调用命中缓存
families2
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
Equal
(
t
,
families
,
families2
)
}
func
TestGetModelFamilies_ExpiredCacheRefreshes
(
t
*
testing
.
T
)
{
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
h
:=
&
SoraClientHandler
{
cachedFamilies
:
[]
service
.
SoraModelFamily
{{
ID
:
"old"
}},
modelCacheTime
:
time
.
Now
()
.
Add
(
-
10
*
time
.
Minute
),
// 已过期
modelCacheUpstream
:
false
,
}
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
// 缓存已刷新,不再是 "old"
found
:=
false
for
_
,
f
:=
range
families
{
if
f
.
ID
==
"old"
{
found
=
true
}
}
require
.
False
(
t
,
found
,
"过期缓存应被刷新"
)
}
// ==================== processGeneration: groupID 与 ForcePlatform ====================
func
TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails
(
t
*
testing
.
T
)
{
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 空账号列表 → SelectAccountForModel 失败
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"选择账号失败"
)
}
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
func
TestGenerate_CheckQuotaNonQuotaError
(
t
*
testing
.
T
)
{
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
NewSoraClientHandler
(
genService
,
quotaService
,
nil
,
nil
,
nil
,
nil
,
nil
)
body
:=
`{"model":"sora2-landscape-10s","prompt":"test"}`
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
body
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
// ==================== Generate: CreatePending 并发限制错误 ====================
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
type
stubSoraGenRepoWithAtomicCreate
struct
{
stubSoraGenRepo
limitErr
error
}
func
(
r
*
stubSoraGenRepoWithAtomicCreate
)
CreatePendingWithLimit
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
,
_
[]
string
,
_
int64
)
error
{
if
r
.
limitErr
!=
nil
{
return
r
.
limitErr
}
return
r
.
stubSoraGenRepo
.
Create
(
context
.
Background
(),
gen
)
}
func
TestGenerate_CreatePendingConcurrencyLimit
(
t
*
testing
.
T
)
{
repo
:=
&
stubSoraGenRepoWithAtomicCreate
{
stubSoraGenRepo
:
*
newStubSoraGenRepo
(),
limitErr
:
service
.
ErrSoraGenerationConcurrencyLimit
,
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
NewSoraClientHandler
(
genService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
body
:=
`{"model":"sora2-landscape-10s","prompt":"test"}`
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
body
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"3"
)
}
// ==================== SaveToStorage: 配额超限 ====================
func
TestSaveToStorage_QuotaExceeded
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户配额已满
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
,
SoraStorageUsedBytes
:
10
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
func
TestSaveToStorage_QuotaNonQuotaError
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== SaveToStorage: MediaURLs 全为空 ====================
func
TestSaveToStorage_EmptyMediaURLs
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
""
,
MediaURLs
:
[]
string
{},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"已过期"
)
}
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
func
TestSaveToStorage_MultiURL_SecondUploadFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail-second"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v1.mp4"
,
MediaURLs
:
[]
string
{
sourceServer
.
URL
+
"/v1.mp4"
,
sourceServer
.
URL
+
"/v2.mp4"
},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
func
TestSaveToStorage_MarkCompletedFailsWithQuotaRollback
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
func
TestCleanupStoredMedia_WithS3Storage_ActualDelete
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
,
"key2"
},
nil
)
}
func
TestCleanupStoredMedia_S3DeleteFails_LogOnly
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
},
nil
)
}
func
TestCleanupStoredMedia_LocalDeleteFails_LogOnly
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-cleanup-fail-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
[]
string
{
"nonexistent/file.mp4"
})
}
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
func
TestDeleteGeneration_LocalStorageDeleteFails_LogOnly
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-del-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"nonexistent/video.mp4"
,
MediaURLs
:
[]
string
{
"nonexistent/video.mp4"
},
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== CancelGeneration: 任务已结束冲突 ====================
func
TestCancelGeneration_AlreadyCompleted
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
backend/internal/handler/sora_gateway_handler.go
deleted
100644 → 0
View file @
7b83d6e7
package
handler
import
(
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type
SoraGatewayHandler
struct
{
gatewayService
*
service
.
GatewayService
soraGatewayService
*
service
.
SoraGatewayService
billingCacheService
*
service
.
BillingCacheService
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
streamMode
string
soraTLSEnabled
bool
soraMediaSigningKey
string
soraMediaRoot
string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func
NewSoraGatewayHandler
(
gatewayService
*
service
.
GatewayService
,
soraGatewayService
*
service
.
SoraGatewayService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
,
cfg
*
config
.
Config
,
)
*
SoraGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
3
streamMode
:=
"force"
soraTLSEnabled
:=
true
signKey
:=
""
mediaRoot
:=
"/app/data/sora"
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
if
cfg
.
Gateway
.
MaxAccountSwitches
>
0
{
maxAccountSwitches
=
cfg
.
Gateway
.
MaxAccountSwitches
}
if
mode
:=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraStreamMode
);
mode
!=
""
{
streamMode
=
mode
}
soraTLSEnabled
=
!
cfg
.
Sora
.
Client
.
DisableTLSFingerprint
signKey
=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraMediaSigningKey
)
if
root
:=
strings
.
TrimSpace
(
cfg
.
Sora
.
Storage
.
LocalPath
);
root
!=
""
{
mediaRoot
=
root
}
}
return
&
SoraGatewayHandler
{
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
billingCacheService
:
billingCacheService
,
usageRecordWorkerPool
:
usageRecordWorkerPool
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
streamMode
:
strings
.
ToLower
(
streamMode
),
soraTLSEnabled
:
soraTLSEnabled
,
soraMediaSigningKey
:
signKey
,
soraMediaRoot
:
mediaRoot
,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func
(
h
*
SoraGatewayHandler
)
ChatCompletions
(
c
*
gin
.
Context
)
{
apiKey
,
ok
:=
middleware2
.
GetAPIKeyFromContext
(
c
)
if
!
ok
{
h
.
errorResponse
(
c
,
http
.
StatusUnauthorized
,
"authentication_error"
,
"Invalid API key"
)
return
}
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"User context not found"
)
return
}
reqLog
:=
requestLogger
(
c
,
"handler.sora_gateway.chat_completions"
,
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
)
body
,
err
:=
pkghttputil
.
ReadRequestBodyWithPrealloc
(
c
.
Request
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to read request body"
)
return
}
if
len
(
body
)
==
0
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Request body is empty"
)
return
}
setOpsRequestContext
(
c
,
""
,
false
,
body
)
// 校验请求体 JSON 合法性
if
!
gjson
.
ValidBytes
(
body
)
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult
:=
gjson
.
GetBytes
(
body
,
"model"
)
if
!
modelResult
.
Exists
()
||
modelResult
.
Type
!=
gjson
.
String
||
modelResult
.
String
()
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
reqModel
:=
modelResult
.
String
()
msgsResult
:=
gjson
.
GetBytes
(
body
,
"messages"
)
if
!
msgsResult
.
IsArray
()
||
len
(
msgsResult
.
Array
())
==
0
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"messages is required"
)
return
}
clientStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
clientStream
))
if
!
clientStream
{
if
h
.
streamMode
==
"error"
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Sora requires stream=true"
)
return
}
var
err
error
body
,
err
=
sjson
.
SetBytes
(
body
,
"stream"
,
true
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to process request"
)
return
}
}
setOpsRequestContext
(
c
,
reqModel
,
clientStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
clientStream
,
false
)))
platform
:=
""
if
forced
,
ok
:=
middleware2
.
GetForcePlatformFromContext
(
c
);
ok
{
platform
=
forced
}
else
if
apiKey
.
Group
!=
nil
{
platform
=
apiKey
.
Group
.
Platform
}
if
platform
!=
service
.
PlatformSora
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"This endpoint only supports Sora platform"
)
return
}
streamStarted
:=
false
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
maxWait
:=
service
.
CalculateMaxWait
(
subject
.
Concurrency
)
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
,
maxWait
)
waitCounted
:=
false
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.user_wait_counter_increment_failed"
,
zap
.
Error
(
err
))
}
else
if
!
canWait
{
reqLog
.
Info
(
"sora.user_wait_queue_full"
,
zap
.
Int
(
"max_wait"
,
maxWait
))
h
.
errorResponse
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
)
return
}
if
err
==
nil
&&
canWait
{
waitCounted
=
true
}
defer
func
()
{
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
}
}()
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
clientStream
,
&
streamStarted
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.user_slot_acquire_failed"
,
zap
.
Error
(
err
))
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
waitCounted
=
false
}
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
defer
userReleaseFunc
()
}
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"sora.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
sessionHash
:=
generateOpenAISessionHash
(
c
,
body
)
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
var
lastFailoverBody
[]
byte
var
lastFailoverHeaders
http
.
Header
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
,
""
,
int64
(
0
))
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_select_failed"
,
zap
.
Error
(
err
),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)),
)
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int
(
"last_upstream_status"
,
lastFailoverStatus
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.failover_exhausted_no_available_accounts"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
proxyBound
:=
account
.
ProxyID
!=
nil
proxyID
:=
int64
(
0
)
if
account
.
ProxyID
!=
nil
{
proxyID
=
*
account
.
ProxyID
}
tlsFingerprintEnabled
:=
h
.
soraTLSEnabled
accountReleaseFunc
:=
selection
.
ReleaseFunc
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_wait_counter_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
}
else
if
!
canWait
{
reqLog
.
Info
(
"sora.account_wait_queue_full"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"max_waiting"
,
selection
.
WaitPlan
.
MaxWaiting
),
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}()
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
clientStream
,
&
streamStarted
,
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
result
,
err
:=
h
.
soraGatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
clientStream
)
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverStatus
=
failoverErr
.
StatusCode
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_exhausted"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
lastFailoverStatus
=
failoverErr
.
StatusCode
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
switchCount
++
upstreamErrCode
,
upstreamErrMsg
:=
extractUpstreamErrorCodeAndMessage
(
lastFailoverBody
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
String
(
"upstream_error_code"
,
upstreamErrCode
),
zap
.
String
(
"upstream_error_message"
,
upstreamErrMsg
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_switching"
,
fields
...
)
continue
}
reqLog
.
Error
(
"sora.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
return
}
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
inboundEndpoint
:=
GetInboundEndpoint
(
c
)
upstreamEndpoint
:=
GetUpstreamEndpoint
(
c
,
account
.
Platform
)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
Account
:
account
,
Subscription
:
subscription
,
InboundEndpoint
:
inboundEndpoint
,
UpstreamEndpoint
:
upstreamEndpoint
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.sora_gateway.chat_completions"
),
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
zap
.
String
(
"model"
,
reqModel
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
)
.
Error
(
"sora.record_usage_failed"
,
zap
.
Error
(
err
))
}
})
reqLog
.
Debug
(
"sora.request_completed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"switch_count"
,
switchCount
),
)
return
}
}
func
generateOpenAISessionHash
(
c
*
gin
.
Context
,
body
[]
byte
)
string
{
if
c
==
nil
{
return
""
}
sessionID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"session_id"
))
if
sessionID
==
""
{
sessionID
=
strings
.
TrimSpace
(
c
.
GetHeader
(
"conversation_id"
))
}
if
sessionID
==
""
&&
len
(
body
)
>
0
{
sessionID
=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"prompt_cache_key"
)
.
String
())
}
if
sessionID
==
""
{
return
""
}
hash
:=
sha256
.
Sum256
([]
byte
(
sessionID
))
return
hex
.
EncodeToString
(
hash
[
:
])
}
func
(
h
*
SoraGatewayHandler
)
submitUsageRecordTask
(
task
service
.
UsageRecordTask
)
{
if
task
==
nil
{
return
}
if
h
.
usageRecordWorkerPool
!=
nil
{
h
.
usageRecordWorkerPool
.
Submit
(
task
)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
defer
func
()
{
if
recovered
:=
recover
();
recovered
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.sora_gateway.chat_completions"
),
zap
.
Any
(
"panic"
,
recovered
),
)
.
Error
(
"sora.usage_record_task_panic_recovered"
)
}
}()
task
(
ctx
)
}
func
(
h
*
SoraGatewayHandler
)
handleConcurrencyError
(
c
*
gin
.
Context
,
err
error
,
slotType
string
,
streamStarted
bool
)
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
,
streamStarted
bool
)
{
upstreamMsg
:=
service
.
ExtractUpstreamErrorMessage
(
responseBody
)
service
.
SetOpsUpstreamError
(
c
,
statusCode
,
upstreamMsg
,
""
)
status
,
errType
,
errMsg
:=
h
.
mapUpstreamError
(
statusCode
,
responseHeaders
,
responseBody
)
h
.
handleStreamingAwareError
(
c
,
status
,
errType
,
errMsg
,
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
mapUpstreamError
(
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
)
(
int
,
string
,
string
)
{
if
isSoraCloudflareChallengeResponse
(
statusCode
,
responseHeaders
,
responseBody
)
{
baseMsg
:=
fmt
.
Sprintf
(
"Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry."
,
statusCode
)
return
http
.
StatusBadGateway
,
"upstream_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
upstreamCode
,
upstreamMessage
:=
extractUpstreamErrorCodeAndMessage
(
responseBody
)
if
strings
.
EqualFold
(
upstreamCode
,
"cf_shield_429"
)
{
baseMsg
:=
"Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
if
shouldPassthroughSoraUpstreamMessage
(
statusCode
,
upstreamMessage
)
{
switch
statusCode
{
case
401
,
403
,
404
,
500
,
502
,
503
,
504
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
upstreamMessage
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
upstreamMessage
}
}
switch
statusCode
{
case
401
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream authentication failed, please contact administrator"
case
403
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream access forbidden, please contact administrator"
case
404
:
if
strings
.
EqualFold
(
upstreamCode
,
"unsupported_country_code"
)
{
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream region capability unavailable for this account, please contact administrator"
}
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream capability unavailable for this account, please contact administrator"
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Upstream rate limit exceeded, please retry later"
case
529
:
return
http
.
StatusServiceUnavailable
,
"upstream_error"
,
"Upstream service overloaded, please retry later"
case
500
,
502
,
503
,
504
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream service temporarily unavailable"
default
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
}
}
func
cloneHTTPHeaders
(
headers
http
.
Header
)
http
.
Header
{
if
headers
==
nil
{
return
nil
}
return
headers
.
Clone
()
}
func
extractSoraFailoverHeaderInsights
(
headers
http
.
Header
,
body
[]
byte
)
(
rayID
,
mitigated
,
contentType
string
)
{
if
headers
!=
nil
{
mitigated
=
strings
.
TrimSpace
(
headers
.
Get
(
"cf-mitigated"
))
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"content-type"
))
if
contentType
==
""
{
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"Content-Type"
))
}
}
rayID
=
soraerror
.
ExtractCloudflareRayID
(
headers
,
body
)
return
rayID
,
mitigated
,
contentType
}
func
isSoraCloudflareChallengeResponse
(
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
bool
{
return
soraerror
.
IsCloudflareChallengeResponse
(
statusCode
,
headers
,
body
)
}
func
shouldPassthroughSoraUpstreamMessage
(
statusCode
int
,
message
string
)
bool
{
message
=
strings
.
TrimSpace
(
message
)
if
message
==
""
{
return
false
}
if
statusCode
==
http
.
StatusForbidden
||
statusCode
==
http
.
StatusTooManyRequests
{
lower
:=
strings
.
ToLower
(
message
)
if
strings
.
Contains
(
lower
,
"<html"
)
||
strings
.
Contains
(
lower
,
"<!doctype html"
)
||
strings
.
Contains
(
lower
,
"window._cf_chl_opt"
)
{
return
false
}
}
return
true
}
func
formatSoraCloudflareChallengeMessage
(
base
string
,
headers
http
.
Header
,
body
[]
byte
)
string
{
return
soraerror
.
FormatCloudflareChallengeMessage
(
base
,
headers
,
body
)
}
func
extractUpstreamErrorCodeAndMessage
(
body
[]
byte
)
(
string
,
string
)
{
return
soraerror
.
ExtractUpstreamErrorCodeAndMessage
(
body
)
}
func
(
h
*
SoraGatewayHandler
)
handleStreamingAwareError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
streamStarted
bool
)
{
if
streamStarted
{
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
ok
{
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
);
err
!=
nil
{
_
=
c
.
Error
(
err
)
}
flusher
.
Flush
()
}
return
}
h
.
errorResponse
(
c
,
status
,
errType
,
message
)
}
func
(
h
*
SoraGatewayHandler
)
errorResponse
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
,
},
})
}
// MediaProxy serves local Sora media files.
func
(
h
*
SoraGatewayHandler
)
MediaProxy
(
c
*
gin
.
Context
)
{
h
.
proxySoraMedia
(
c
,
false
)
}
// MediaProxySigned serves local Sora media files with signature verification.
func
(
h
*
SoraGatewayHandler
)
MediaProxySigned
(
c
*
gin
.
Context
)
{
h
.
proxySoraMedia
(
c
,
true
)
}
func
(
h
*
SoraGatewayHandler
)
proxySoraMedia
(
c
*
gin
.
Context
,
requireSignature
bool
)
{
rawPath
:=
c
.
Param
(
"filepath"
)
if
rawPath
==
""
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
cleaned
:=
path
.
Clean
(
rawPath
)
if
!
strings
.
HasPrefix
(
cleaned
,
"/image/"
)
&&
!
strings
.
HasPrefix
(
cleaned
,
"/video/"
)
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
query
:=
c
.
Request
.
URL
.
Query
()
if
requireSignature
{
if
h
.
soraMediaSigningKey
==
""
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 媒体签名未配置"
,
},
})
return
}
expiresStr
:=
strings
.
TrimSpace
(
query
.
Get
(
"expires"
))
signature
:=
strings
.
TrimSpace
(
query
.
Get
(
"sig"
))
expires
,
err
:=
strconv
.
ParseInt
(
expiresStr
,
10
,
64
)
if
err
!=
nil
||
expires
<=
time
.
Now
()
.
Unix
()
{
c
.
JSON
(
http
.
StatusUnauthorized
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"authentication_error"
,
"message"
:
"Sora 媒体签名已过期"
,
},
})
return
}
query
.
Del
(
"sig"
)
query
.
Del
(
"expires"
)
signingQuery
:=
query
.
Encode
()
if
!
service
.
VerifySoraMediaURL
(
cleaned
,
signingQuery
,
expires
,
signature
,
h
.
soraMediaSigningKey
)
{
c
.
JSON
(
http
.
StatusUnauthorized
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"authentication_error"
,
"message"
:
"Sora 媒体签名无效"
,
},
})
return
}
}
if
strings
.
TrimSpace
(
h
.
soraMediaRoot
)
==
""
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 媒体目录未配置"
,
},
})
return
}
relative
:=
strings
.
TrimPrefix
(
cleaned
,
"/"
)
localPath
:=
filepath
.
Join
(
h
.
soraMediaRoot
,
filepath
.
FromSlash
(
relative
))
if
_
,
err
:=
os
.
Stat
(
localPath
);
err
!=
nil
{
if
os
.
IsNotExist
(
err
)
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
c
.
Status
(
http
.
StatusInternalServerError
)
return
}
c
.
File
(
localPath
)
}
backend/internal/handler/sora_gateway_handler_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
handler
import
(
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var
_
service
.
SoraClient
=
(
*
stubSoraClient
)(
nil
)
var
_
service
.
AccountRepository
=
(
*
stubAccountRepo
)(
nil
)
var
_
service
.
GroupRepository
=
(
*
stubGroupRepo
)(
nil
)
var
_
service
.
UsageLogRepository
=
(
*
stubUsageLogRepo
)(
nil
)
type
stubSoraClient
struct
{
imageURLs
[]
string
}
func
(
s
*
stubSoraClient
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClient
)
UploadImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
return
"upload"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraVideoRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraStoryboardRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClient
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
(
*
service
.
SoraCameoStatus
,
error
)
{
return
&
service
.
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClient
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClient
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClient
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
service
.
Account
,
generationID
string
)
(
string
,
error
)
{
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClient
)
DeletePost
(
ctx
context
.
Context
,
account
*
service
.
Account
,
postID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
service
.
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClient
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
service
.
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
return
"enhanced prompt"
,
nil
}
func
(
s
*
stubSoraClient
)
GetImageTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
taskID
string
)
(
*
service
.
SoraImageTaskStatus
,
error
)
{
return
&
service
.
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
s
.
imageURLs
},
nil
}
func
(
s
*
stubSoraClient
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
taskID
string
)
(
*
service
.
SoraVideoTaskStatus
,
error
)
{
return
&
service
.
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
s
.
imageURLs
},
nil
}
type
stubAccountRepo
struct
{
accounts
map
[
int64
]
*
service
.
Account
}
func
(
r
*
stubAccountRepo
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
if
acc
,
ok
:=
r
.
accounts
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
service
.
ErrAccountNotFound
}
func
(
r
*
stubAccountRepo
)
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
service
.
Account
,
error
)
{
var
result
[]
*
service
.
Account
for
_
,
id
:=
range
ids
{
if
acc
,
ok
:=
r
.
accounts
[
id
];
ok
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
r
*
stubAccountRepo
)
ExistsByID
(
ctx
context
.
Context
,
id
int64
)
(
bool
,
error
)
{
_
,
ok
:=
r
.
accounts
[
id
]
return
ok
,
nil
}
func
(
r
*
stubAccountRepo
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
map
[
string
]
int64
{},
nil
}
func
(
r
*
stubAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepo
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulable
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulable
(),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulable
(),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
for
_
,
platform
:=
range
platforms
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
break
}
}
}
return
result
,
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
r
*
stubAccountRepo
)
ListSchedulableUngroupedByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
func
(
r
*
stubAccountRepo
)
ListSchedulableUngroupedByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
r
*
stubAccountRepo
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ResetQuotaUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
listSchedulable
()
[]
service
.
Account
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
}
}
return
result
}
func
(
r
*
stubAccountRepo
)
listSchedulableByPlatform
(
platform
string
)
[]
service
.
Account
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
}
}
return
result
}
type
stubGroupRepo
struct
{
group
*
service
.
Group
}
func
(
r
*
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
r
.
group
,
nil
}
func
(
r
*
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
r
.
group
,
nil
}
func
(
r
*
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubGroupRepo
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
int64
,
error
)
{
return
0
,
0
,
nil
}
func
(
r
*
stubGroupRepo
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubGroupRepo
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
service
.
GroupSortOrderUpdate
)
error
{
return
nil
}
type
stubUsageLogRepo
struct
{}
func
(
s
*
stubUsageLogRepo
)
Create
(
ctx
context
.
Context
,
log
*
service
.
UsageLog
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
UsageLog
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAPIKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAPIKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserBreakdownStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
dim
usagestats
.
UserBreakdownDimension
,
limit
int
)
([]
usagestats
.
UserBreakdownItem
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAllGroupUsageSummary
(
ctx
context
.
Context
,
todayStart
time
.
Time
)
([]
usagestats
.
GroupUsageSummary
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserSpendingRanking
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
limit
int
)
(
*
usagestats
.
UserSpendingRankingResponse
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
usagestats
.
UserDashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyDashboardStats
(
ctx
context
.
Context
,
apiKeyID
int64
)
(
*
usagestats
.
UserDashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserModelStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
usagestats
.
UsageLogFilters
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
AccountUsageStatsResponse
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyStatsAggregated
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountStatsAggregated
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetModelStatsAggregated
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetDailyStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
map
[
string
]
any
,
error
)
{
return
nil
,
nil
}
func
TestSoraGatewayHandler_ChatCompletions
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
,
Gateway
:
config
.
GatewayConfig
{
SoraStreamMode
:
"force"
,
MaxAccountSwitches
:
1
,
Scheduling
:
config
.
GatewaySchedulingConfig
{
LoadBatchEnabled
:
false
,
},
},
Concurrency
:
config
.
ConcurrencyConfig
{
PingInterval
:
0
},
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
BaseURL
:
"https://sora.test"
,
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
account
:=
&
service
.
Account
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
}
accountRepo
:=
&
stubAccountRepo
{
accounts
:
map
[
int64
]
*
service
.
Account
{
account
.
ID
:
account
}}
group
:=
&
service
.
Group
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
}
groupRepo
:=
&
stubGroupRepo
{
group
:
group
}
usageLogRepo
:=
&
stubUsageLogRepo
{}
deferredService
:=
service
.
NewDeferredService
(
accountRepo
,
nil
,
0
)
billingService
:=
service
.
NewBillingService
(
cfg
,
nil
)
concurrencyService
:=
service
.
NewConcurrencyService
(
testutil
.
StubConcurrencyCache
{})
billingCacheService
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cfg
)
t
.
Cleanup
(
func
()
{
billingCacheService
.
Stop
()
})
gatewayService
:=
service
.
NewGatewayService
(
accountRepo
,
groupRepo
,
usageLogRepo
,
nil
,
nil
,
nil
,
nil
,
testutil
.
StubGatewayCache
{},
cfg
,
nil
,
concurrencyService
,
billingService
,
nil
,
billingCacheService
,
nil
,
nil
,
deferredService
,
nil
,
testutil
.
StubSessionLimitCache
{},
nil
,
// rpmCache
nil
,
// digestStore
nil
,
// settingService
nil
,
// tlsFPProfileService
nil
,
// channelService
nil
,
// resolver
)
soraClient
:=
&
stubSoraClient
{
imageURLs
:
[]
string
{
"https://example.com/a.png"
}}
soraGatewayService
:=
service
.
NewSoraGatewayService
(
soraClient
,
nil
,
nil
,
cfg
)
handler
:=
NewSoraGatewayHandler
(
gatewayService
,
soraGatewayService
,
concurrencyService
,
billingCacheService
,
nil
,
cfg
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
`{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/sora/v1/chat/completions"
,
strings
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
apiKey
:=
&
service
.
APIKey
{
ID
:
1
,
UserID
:
1
,
Status
:
service
.
StatusActive
,
GroupID
:
&
group
.
ID
,
User
:
&
service
.
User
{
ID
:
1
,
Concurrency
:
1
,
Status
:
service
.
StatusActive
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
apiKey
.
User
.
Concurrency
})
handler
.
ChatCompletions
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
NotEmpty
(
t
,
resp
[
"media_url"
])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func
TestSoraHandler_StreamForcing
(
t
*
testing
.
T
)
{
// 测试 1:stream=false 时 sjson 强制修改为 true
body
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`
)
clientStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
require
.
False
(
t
,
clientStream
)
newBody
,
err
:=
sjson
.
SetBytes
(
body
,
"stream"
,
true
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
gjson
.
GetBytes
(
newBody
,
"stream"
)
.
Bool
())
// 测试 2:stream=true 时不修改
body2
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`
)
require
.
True
(
t
,
gjson
.
GetBytes
(
body2
,
"stream"
)
.
Bool
())
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
body3
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}]}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body3
,
"stream"
)
.
Bool
())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func
TestSoraHandler_ValidationExtraction
(
t
*
testing
.
T
)
{
// model 缺失
body
:=
[]
byte
(
`{"messages":[{"role":"user","content":"test"}]}`
)
modelResult
:=
gjson
.
GetBytes
(
body
,
"model"
)
require
.
True
(
t
,
!
modelResult
.
Exists
()
||
modelResult
.
Type
!=
gjson
.
String
||
modelResult
.
String
()
==
""
)
// model 为数字 → 类型不是 gjson.String,应被拒绝
body1b
:=
[]
byte
(
`{"model":123,"messages":[{"role":"user","content":"test"}]}`
)
modelResult1b
:=
gjson
.
GetBytes
(
body1b
,
"model"
)
require
.
True
(
t
,
modelResult1b
.
Exists
())
require
.
NotEqual
(
t
,
gjson
.
String
,
modelResult1b
.
Type
)
// messages 缺失
body2
:=
[]
byte
(
`{"model":"sora"}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body2
,
"messages"
)
.
IsArray
())
// messages 不是 JSON 数组(字符串)
body3
:=
[]
byte
(
`{"model":"sora","messages":"not array"}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body3
,
"messages"
)
.
IsArray
())
// messages 是对象而非数组 → IsArray 返回 false
body4
:=
[]
byte
(
`{"model":"sora","messages":{}}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body4
,
"messages"
)
.
IsArray
())
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
body5
:=
[]
byte
(
`{"model":"sora","messages":[]}`
)
msgsResult
:=
gjson
.
GetBytes
(
body5
,
"messages"
)
require
.
True
(
t
,
msgsResult
.
IsArray
())
require
.
Equal
(
t
,
0
,
len
(
msgsResult
.
Array
()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require
.
False
(
t
,
gjson
.
ValidBytes
([]
byte
(
`{invalid`
)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func
TestGenerateOpenAISessionHash_WithBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
// 从 body 提取 prompt_cache_key
body
:=
[]
byte
(
`{"model":"sora","prompt_cache_key":"session-abc"}`
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
"POST"
,
"/"
,
nil
)
hash
:=
generateOpenAISessionHash
(
c
,
body
)
require
.
NotEmpty
(
t
,
hash
)
// 无 prompt_cache_key 且无 header → 空 hash
body2
:=
[]
byte
(
`{"model":"sora"}`
)
hash2
:=
generateOpenAISessionHash
(
c
,
body2
)
require
.
Empty
(
t
,
hash2
)
// header 优先于 body
c
.
Request
.
Header
.
Set
(
"session_id"
,
"from-header"
)
hash3
:=
generateOpenAISessionHash
(
c
,
body
)
require
.
NotEmpty
(
t
,
hash3
)
require
.
NotEqual
(
t
,
hash
,
hash3
)
// 不同来源应产生不同 hash
}
func
TestSoraHandleStreamingAwareError_JSONEscaping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
errType
string
message
string
}{
{
name
:
"包含双引号"
,
errType
:
"upstream_error"
,
message
:
`upstream returned "invalid" payload`
,
},
{
name
:
"包含换行和制表符"
,
errType
:
"rate_limit_error"
,
message
:
"line1
\n
line2
\t
tab"
,
},
{
name
:
"包含反斜杠"
,
errType
:
"upstream_error"
,
message
:
`path C:\Users\test\file.txt not found`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
tt
.
errType
,
tt
.
message
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
),
"应以 SSE error 事件开头"
)
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
),
"应以 SSE 结束分隔符结尾"
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
,
"SSE 错误事件应包含 event 行和 data 行"
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
),
"第二行应为 data 前缀"
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
),
"data 行必须是合法 JSON"
)
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
,
"JSON 中应包含 error 对象"
)
require
.
Equal
(
t
,
tt
.
errType
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
tt
.
message
,
errorObj
[
"message"
])
})
}
}
func
TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
resp
:=
[]
byte
(
`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`
)
h
.
handleFailoverExhausted
(
c
,
http
.
StatusBadGateway
,
nil
,
resp
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
))
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errorObj
[
"message"
])
}
func
TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
body
:=
[]
byte
(
`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusForbidden
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare challenge"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d01b0e9ecc35829-SEA"
)
}
func
TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d03b68c086027a1-SEA"
)
body
:=
[]
byte
(
`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusTooManyRequests
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"rate_limit_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare shield"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d03b68c086027a1-SEA"
)
}
func
TestExtractSoraFailoverHeaderInsights
(
t
*
testing
.
T
)
{
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-mitigated"
,
"challenge"
)
headers
.
Set
(
"content-type"
,
"text/html"
)
body
:=
[]
byte
(
`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
headers
,
body
)
require
.
Equal
(
t
,
"9cff2d62d83bb98d"
,
rayID
)
require
.
Equal
(
t
,
"challenge"
,
mitigated
)
require
.
Equal
(
t
,
"text/html"
,
contentType
)
}
backend/internal/handler/usage_record_submit_task_test.go
View file @
eb2dce92
...
...
@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require
.
True
(
t
,
called
.
Load
(),
"panic 后后续任务应仍可执行"
)
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool
(
t
*
testing
.
T
)
{
pool
:=
newUsageRecordTestPool
(
t
)
h
:=
&
SoraGatewayHandler
{
usageRecordWorkerPool
:
pool
}
done
:=
make
(
chan
struct
{})
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
close
(
done
)
})
select
{
case
<-
done
:
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"task not executed"
)
}
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
var
called
atomic
.
Bool
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
_
,
ok
:=
ctx
.
Deadline
();
!
ok
{
t
.
Fatal
(
"expected deadline in fallback context"
)
}
called
.
Store
(
true
)
})
require
.
True
(
t
,
called
.
Load
())
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
require
.
NotPanics
(
t
,
func
()
{
h
.
submitUsageRecordTask
(
nil
)
})
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
var
called
atomic
.
Bool
require
.
NotPanics
(
t
,
func
()
{
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
panic
(
"usage task panic"
)
})
})
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
called
.
Store
(
true
)
})
require
.
True
(
t
,
called
.
Load
(),
"panic 后后续任务应仍可执行"
)
}
backend/internal/handler/wire.go
View file @
eb2dce92
...
...
@@ -86,8 +86,6 @@ func ProvideHandlers(
adminHandlers
*
AdminHandlers
,
gatewayHandler
*
GatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
soraGatewayHandler
*
SoraGatewayHandler
,
soraClientHandler
*
SoraClientHandler
,
settingHandler
*
SettingHandler
,
totpHandler
*
TotpHandler
,
_
*
service
.
IdempotencyCoordinator
,
...
...
@@ -104,8 +102,6 @@ func ProvideHandlers(
Admin
:
adminHandlers
,
Gateway
:
gatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
SoraGateway
:
soraGatewayHandler
,
SoraClient
:
soraClientHandler
,
Setting
:
settingHandler
,
Totp
:
totpHandler
,
}
...
...
@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler
,
NewGatewayHandler
,
NewOpenAIGatewayHandler
,
NewSoraGatewayHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
...
...
backend/internal/pkg/antigravity/oauth.go
View file @
eb2dce92
...
...
@@ -50,7 +50,7 @@ const (
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
var
defaultUserAgentVersion
=
"1.2
0.5
"
var
defaultUserAgentVersion
=
"1.2
1.9
"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var
defaultClientSecret
=
"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"
...
...
backend/internal/pkg/antigravity/oauth_test.go
View file @
eb2dce92
...
...
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
if
RedirectURI
!=
"http://localhost:8085/callback"
{
t
.
Errorf
(
"RedirectURI 不匹配: got %s"
,
RedirectURI
)
}
if
GetUserAgent
()
!=
"antigravity/1.2
0.5
windows/amd64"
{
if
GetUserAgent
()
!=
"antigravity/1.2
1.9
windows/amd64"
{
t
.
Errorf
(
"UserAgent 不匹配: got %s"
,
GetUserAgent
())
}
if
SessionTTL
!=
30
*
time
.
Minute
{
...
...
backend/internal/pkg/openai/oauth.go
View file @
eb2dce92
...
...
@@ -17,8 +17,6 @@ import (
const
(
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID
=
"app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID
=
"app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL
=
"https://auth.openai.com/oauth/authorize"
...
...
@@ -39,8 +37,6 @@ const (
const
(
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI
=
"openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora
=
"sora"
)
// OAuthSession stores OAuth flow state for OpenAI
...
...
@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func
OAuthClientConfigByPlatform
(
platform
string
)
(
clientID
string
,
codexFlow
bool
)
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
platform
))
{
case
OAuthPlatformSora
:
return
ClientID
,
false
default
:
return
ClientID
,
true
}
return
ClientID
,
true
}
// TokenRequest represents the token exchange request body
...
...
backend/internal/pkg/openai/oauth_test.go
View file @
eb2dce92
...
...
@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func
TestBuildAuthorizationURLForPlatform_Sora
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-2"
,
"challenge-2"
,
DefaultRedirectURI
,
OAuthPlatformSora
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
""
{
t
.
Fatalf
(
"codex flow should be empty for sora, got=%q"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
backend/internal/repository/account_repo.go
View file @
eb2dce92
...
...
@@ -1692,20 +1692,13 @@ func itoa(v int) string {
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func
(
r
*
accountRepository
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
dbaccount
.
PlatformEQ
(
"sora"
),
// 限定平台为 sora
dbaccount
.
DeletedAtIsNil
(),
func
(
s
*
entsql
.
Selector
)
{
path
:=
sqljson
.
Path
(
key
)
...
...
backend/internal/repository/api_key_repo.go
View file @
eb2dce92
...
...
@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldImagePrice1k
,
group
.
FieldImagePrice2k
,
group
.
FieldImagePrice4k
,
group
.
FieldSoraImagePrice360
,
group
.
FieldSoraImagePrice540
,
group
.
FieldSoraVideoPricePerRequest
,
group
.
FieldSoraVideoPricePerRequestHd
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
...
...
@@ -608,22 +604,20 @@ func userEntityToService(u *dbent.User) *service.User {
return
nil
}
return
&
service
.
User
{
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Notes
:
u
.
Notes
,
PasswordHash
:
u
.
PasswordHash
,
Role
:
u
.
Role
,
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
SoraStorageQuotaBytes
:
u
.
SoraStorageQuotaBytes
,
SoraStorageUsedBytes
:
u
.
SoraStorageUsedBytes
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Notes
:
u
.
Notes
,
PasswordHash
:
u
.
PasswordHash
,
Role
:
u
.
Role
,
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
}
...
...
@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K
:
g
.
ImagePrice1k
,
ImagePrice2K
:
g
.
ImagePrice2k
,
ImagePrice4K
:
g
.
ImagePrice4k
,
SoraImagePrice360
:
g
.
SoraImagePrice360
,
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHd
,
SoraStorageQuotaBytes
:
g
.
SoraStorageQuotaBytes
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
...
...
backend/internal/repository/group_repo.go
View file @
eb2dce92
...
...
@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableSoraImagePrice360
(
groupIn
.
SoraImagePrice360
)
.
SetNillableSoraImagePrice540
(
groupIn
.
SoraImagePrice540
)
.
SetNillableSoraVideoPricePerRequest
(
groupIn
.
SoraVideoPricePerRequest
)
.
SetNillableSoraVideoPricePerRequestHd
(
groupIn
.
SoraVideoPricePerRequestHD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
...
...
@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableSoraImagePrice360
(
groupIn
.
SoraImagePrice360
)
.
SetNillableSoraImagePrice540
(
groupIn
.
SoraImagePrice540
)
.
SetNillableSoraVideoPricePerRequest
(
groupIn
.
SoraVideoPricePerRequest
)
.
SetNillableSoraVideoPricePerRequestHd
(
groupIn
.
SoraVideoPricePerRequestHD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
...
...
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment