Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
62e80c60
Commit
62e80c60
authored
Apr 05, 2026
by
erio
Browse files
revert: completely remove all Sora functionality
parent
dbb248df
Changes
136
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/sora_client_handler.go
deleted
100644 → 0
View file @
dbb248df
package
handler
import
(
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const
(
// 上游模型缓存 TTL
modelCacheTTL
=
1
*
time
.
Hour
// 上游获取成功
modelCacheFailedTTL
=
2
*
time
.
Minute
// 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type
SoraClientHandler
struct
{
genService
*
service
.
SoraGenerationService
quotaService
*
service
.
SoraQuotaService
s3Storage
*
service
.
SoraS3Storage
soraGatewayService
*
service
.
SoraGatewayService
gatewayService
*
service
.
GatewayService
mediaStorage
*
service
.
SoraMediaStorage
apiKeyService
*
service
.
APIKeyService
// 上游模型缓存
modelCacheMu
sync
.
RWMutex
cachedFamilies
[]
service
.
SoraModelFamily
modelCacheTime
time
.
Time
modelCacheUpstream
bool
// 是否来自上游(决定 TTL)
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func
NewSoraClientHandler
(
genService
*
service
.
SoraGenerationService
,
quotaService
*
service
.
SoraQuotaService
,
s3Storage
*
service
.
SoraS3Storage
,
soraGatewayService
*
service
.
SoraGatewayService
,
gatewayService
*
service
.
GatewayService
,
mediaStorage
*
service
.
SoraMediaStorage
,
apiKeyService
*
service
.
APIKeyService
,
)
*
SoraClientHandler
{
return
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
s3Storage
:
s3Storage
,
soraGatewayService
:
soraGatewayService
,
gatewayService
:
gatewayService
,
mediaStorage
:
mediaStorage
,
apiKeyService
:
apiKeyService
,
}
}
// GenerateRequest 生成请求。
type
GenerateRequest
struct
{
Model
string
`json:"model" binding:"required"`
Prompt
string
`json:"prompt" binding:"required"`
MediaType
string
`json:"media_type"`
// video / image,默认 video
VideoCount
int
`json:"video_count,omitempty"`
// 视频数量(1-3)
ImageInput
string
`json:"image_input,omitempty"`
// 参考图(base64 或 URL)
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
// 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func
(
h
*
SoraClientHandler
)
Generate
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
var
req
GenerateRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"参数错误: "
+
err
.
Error
())
return
}
if
req
.
MediaType
==
""
{
req
.
MediaType
=
"video"
}
req
.
VideoCount
=
normalizeVideoCount
(
req
.
MediaType
,
req
.
VideoCount
)
// 并发数检查(最多 3 个)
activeCount
,
err
:=
h
.
genService
.
CountActiveByUser
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
activeCount
>=
3
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"同时进行中的任务不能超过 3 个"
)
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
CheckQuota
(
c
.
Request
.
Context
(),
userID
,
0
);
err
!=
nil
{
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
response
.
Error
(
c
,
http
.
StatusForbidden
,
err
.
Error
())
return
}
}
// 获取 API Key ID 和 Group ID
var
apiKeyID
*
int64
var
groupID
*
int64
if
req
.
APIKeyID
!=
nil
&&
h
.
apiKeyService
!=
nil
{
// 前端传递了 api_key_id,需要校验
apiKey
,
err
:=
h
.
apiKeyService
.
GetByID
(
c
.
Request
.
Context
(),
*
req
.
APIKeyID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"API Key 不存在"
)
return
}
if
apiKey
.
UserID
!=
userID
{
response
.
Error
(
c
,
http
.
StatusForbidden
,
"API Key 不属于当前用户"
)
return
}
if
apiKey
.
Status
!=
service
.
StatusAPIKeyActive
{
response
.
Error
(
c
,
http
.
StatusForbidden
,
"API Key 不可用"
)
return
}
apiKeyID
=
&
apiKey
.
ID
groupID
=
apiKey
.
GroupID
}
else
if
id
,
ok
:=
c
.
Get
(
"api_key_id"
);
ok
{
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if
v
,
ok
:=
id
.
(
int64
);
ok
{
apiKeyID
=
&
v
}
}
gen
,
err
:=
h
.
genService
.
CreatePending
(
c
.
Request
.
Context
(),
userID
,
apiKeyID
,
req
.
Model
,
req
.
Prompt
,
req
.
MediaType
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationConcurrencyLimit
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"同时进行中的任务不能超过 3 个"
)
return
}
response
.
ErrorFrom
(
c
,
err
)
return
}
// 启动后台异步生成 goroutine
go
h
.
processGeneration
(
gen
.
ID
,
userID
,
groupID
,
req
.
Model
,
req
.
Prompt
,
req
.
MediaType
,
req
.
ImageInput
,
req
.
VideoCount
)
response
.
Success
(
c
,
gin
.
H
{
"generation_id"
:
gen
.
ID
,
"status"
:
gen
.
Status
,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
func
(
h
*
SoraClientHandler
)
processGeneration
(
genID
int64
,
userID
int64
,
groupID
*
int64
,
model
,
prompt
,
mediaType
,
imageInput
string
,
videoCount
int
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Minute
)
defer
cancel
()
// 标记为生成中
if
err
:=
h
.
genService
.
MarkGenerating
(
ctx
,
genID
,
""
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationStateConflict
)
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 任务状态已变化,跳过生成 id=%d"
,
genID
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 标记生成中失败 id=%d err=%v"
,
genID
,
err
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
mediaType
,
videoCount
,
strings
.
TrimSpace
(
imageInput
)
!=
""
,
len
(
strings
.
TrimSpace
(
prompt
)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if
groupID
==
nil
{
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformSora
)
}
if
h
.
gatewayService
==
nil
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"内部错误: gatewayService 未初始化"
)
return
}
// 选择 Sora 账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
ctx
,
groupID
,
""
,
model
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
err
,
)
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"选择账号失败: "
+
err
.
Error
())
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s"
,
genID
,
userID
,
groupIDForLog
(
groupID
),
model
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
account
.
Type
,
)
// 构建 chat completions 请求体(非流式)
body
:=
buildAsyncRequestBody
(
model
,
prompt
,
imageInput
,
normalizeVideoCount
(
mediaType
,
videoCount
))
if
h
.
soraGatewayService
==
nil
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"内部错误: soraGatewayService 未初始化"
)
return
}
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
recorder
:=
httptest
.
NewRecorder
()
mockGinCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
mockGinCtx
.
Request
,
_
=
http
.
NewRequest
(
"POST"
,
"/"
,
nil
)
// 调用 Forward(非流式)
result
,
err
:=
h
.
soraGatewayService
.
Forward
(
ctx
,
mockGinCtx
,
account
,
body
,
false
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v"
,
genID
,
account
.
ID
,
model
,
recorder
.
Code
,
trimForLog
(
recorder
.
Body
.
String
(),
400
),
err
,
)
// 检查是否已取消
gen
,
_
:=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
return
}
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"生成失败: "
+
err
.
Error
())
return
}
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
mediaURL
,
mediaURLs
:=
extractMediaURLsFromResult
(
result
,
recorder
)
if
mediaURL
==
""
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s"
,
genID
,
account
.
ID
,
model
,
recorder
.
Code
,
trimForLog
(
recorder
.
Body
.
String
(),
400
),
)
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"未获取到媒体 URL"
)
return
}
// 检查任务是否已被取消
gen
,
_
:=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 任务已取消,跳过存储 id=%d"
,
genID
)
return
}
// 三层降级存储:S3 → 本地 → 上游临时 URL
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
ctx
,
userID
,
mediaType
,
mediaURL
,
mediaURLs
)
usageAdded
:=
false
if
(
storageType
==
service
.
SoraStorageTypeS3
||
storageType
==
service
.
SoraStorageTypeLocal
)
&&
fileSize
>
0
&&
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
AddUsage
(
ctx
,
userID
,
fileSize
);
err
!=
nil
{
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
_
=
h
.
genService
.
MarkFailed
(
ctx
,
genID
,
"存储配额更新失败: "
+
err
.
Error
())
return
}
usageAdded
=
true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen
,
_
=
h
.
genService
.
GetByID
(
ctx
,
genID
,
userID
)
if
gen
!=
nil
&&
gen
.
Status
==
service
.
SoraGenStatusCancelled
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d"
,
genID
)
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
fileSize
)
}
return
}
// 标记完成
if
err
:=
h
.
genService
.
MarkCompleted
(
ctx
,
genID
,
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationStateConflict
)
{
h
.
cleanupStoredMedia
(
ctx
,
storageType
,
s3Keys
,
storedURLs
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
fileSize
)
}
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 标记完成失败 id=%d err=%v"
,
genID
,
err
)
return
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 生成完成 id=%d storage=%s size=%d"
,
genID
,
storageType
,
fileSize
)
}
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
func
(
h
*
SoraClientHandler
)
storeMediaWithDegradation
(
ctx
context
.
Context
,
userID
int64
,
mediaType
string
,
mediaURL
string
,
mediaURLs
[]
string
,
)
(
storedURL
string
,
storedURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSize
int64
)
{
urls
:=
mediaURLs
if
len
(
urls
)
==
0
{
urls
=
[]
string
{
mediaURL
}
}
// 第一层:尝试 S3
if
h
.
s3Storage
!=
nil
&&
h
.
s3Storage
.
Enabled
(
ctx
)
{
keys
:=
make
([]
string
,
0
,
len
(
urls
))
var
totalSize
int64
allOK
:=
true
for
_
,
u
:=
range
urls
{
key
,
size
,
err
:=
h
.
s3Storage
.
UploadFromURL
(
ctx
,
userID
,
u
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] S3 上传失败 err=%v"
,
err
)
allOK
=
false
// 清理已上传的文件
if
len
(
keys
)
>
0
{
_
=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
keys
)
}
break
}
keys
=
append
(
keys
,
key
)
totalSize
+=
size
}
if
allOK
&&
len
(
keys
)
>
0
{
accessURLs
:=
make
([]
string
,
0
,
len
(
keys
))
for
_
,
key
:=
range
keys
{
accessURL
,
err
:=
h
.
s3Storage
.
GetAccessURL
(
ctx
,
key
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 生成 S3 访问 URL 失败 err=%v"
,
err
)
_
=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
keys
)
allOK
=
false
break
}
accessURLs
=
append
(
accessURLs
,
accessURL
)
}
if
allOK
&&
len
(
accessURLs
)
>
0
{
return
accessURLs
[
0
],
accessURLs
,
service
.
SoraStorageTypeS3
,
keys
,
totalSize
}
}
}
// 第二层:尝试本地存储
if
h
.
mediaStorage
!=
nil
&&
h
.
mediaStorage
.
Enabled
()
{
storedPaths
,
err
:=
h
.
mediaStorage
.
StoreFromURLs
(
ctx
,
mediaType
,
urls
)
if
err
==
nil
&&
len
(
storedPaths
)
>
0
{
firstPath
:=
storedPaths
[
0
]
totalSize
,
sizeErr
:=
h
.
mediaStorage
.
TotalSizeByRelativePaths
(
storedPaths
)
if
sizeErr
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 统计本地文件大小失败 err=%v"
,
sizeErr
)
}
return
firstPath
,
storedPaths
,
service
.
SoraStorageTypeLocal
,
nil
,
totalSize
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 本地存储失败 err=%v"
,
err
)
}
// 第三层:保留上游临时 URL
return
urls
[
0
],
urls
,
service
.
SoraStorageTypeUpstream
,
nil
,
0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func
buildAsyncRequestBody
(
model
,
prompt
,
imageInput
string
,
videoCount
int
)
[]
byte
{
body
:=
map
[
string
]
any
{
"model"
:
model
,
"messages"
:
[]
map
[
string
]
string
{
{
"role"
:
"user"
,
"content"
:
prompt
},
},
"stream"
:
false
,
}
if
imageInput
!=
""
{
body
[
"image_input"
]
=
imageInput
}
if
videoCount
>
1
{
body
[
"video_count"
]
=
videoCount
}
b
,
_
:=
json
.
Marshal
(
body
)
return
b
}
func
normalizeVideoCount
(
mediaType
string
,
videoCount
int
)
int
{
if
mediaType
!=
"video"
{
return
1
}
if
videoCount
<=
0
{
return
1
}
if
videoCount
>
3
{
return
3
}
return
videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径:ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func
extractMediaURLsFromResult
(
result
*
service
.
ForwardResult
,
recorder
*
httptest
.
ResponseRecorder
)
(
string
,
[]
string
)
{
// 优先从 ForwardResult 获取(OAuth 路径)
if
result
!=
nil
&&
result
.
MediaURL
!=
""
{
// 尝试从响应体获取完整 URL 列表
if
urls
:=
parseMediaURLsFromBody
(
recorder
.
Body
.
Bytes
());
len
(
urls
)
>
0
{
return
urls
[
0
],
urls
}
return
result
.
MediaURL
,
[]
string
{
result
.
MediaURL
}
}
// 从响应体解析(APIKey 路径)
if
urls
:=
parseMediaURLsFromBody
(
recorder
.
Body
.
Bytes
());
len
(
urls
)
>
0
{
return
urls
[
0
],
urls
}
return
""
,
nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func
parseMediaURLsFromBody
(
body
[]
byte
)
[]
string
{
if
len
(
body
)
==
0
{
return
nil
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
}
// 优先 media_urls(多图数组)
if
rawURLs
,
ok
:=
resp
[
"media_urls"
];
ok
{
if
arr
,
ok
:=
rawURLs
.
([]
any
);
ok
&&
len
(
arr
)
>
0
{
urls
:=
make
([]
string
,
0
,
len
(
arr
))
for
_
,
item
:=
range
arr
{
if
s
,
ok
:=
item
.
(
string
);
ok
&&
s
!=
""
{
urls
=
append
(
urls
,
s
)
}
}
if
len
(
urls
)
>
0
{
return
urls
}
}
}
// 回退到 media_url(单个 URL)
if
url
,
ok
:=
resp
[
"media_url"
]
.
(
string
);
ok
&&
url
!=
""
{
return
[]
string
{
url
}
}
return
nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func
(
h
*
SoraClientHandler
)
ListGenerations
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
page
,
_
:=
strconv
.
Atoi
(
c
.
DefaultQuery
(
"page"
,
"1"
))
pageSize
,
_
:=
strconv
.
Atoi
(
c
.
DefaultQuery
(
"page_size"
,
"20"
))
params
:=
service
.
SoraGenerationListParams
{
UserID
:
userID
,
Status
:
c
.
Query
(
"status"
),
StorageType
:
c
.
Query
(
"storage_type"
),
MediaType
:
c
.
Query
(
"media_type"
),
Page
:
page
,
PageSize
:
pageSize
,
}
gens
,
total
,
err
:=
h
.
genService
.
List
(
c
.
Request
.
Context
(),
params
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// 为 S3 记录动态生成预签名 URL
for
_
,
gen
:=
range
gens
{
_
=
h
.
genService
.
ResolveMediaURLs
(
c
.
Request
.
Context
(),
gen
)
}
response
.
Success
(
c
,
gin
.
H
{
"data"
:
gens
,
"total"
:
total
,
"page"
:
page
,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func
(
h
*
SoraClientHandler
)
GetGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
_
=
h
.
genService
.
ResolveMediaURLs
(
c
.
Request
.
Context
(),
gen
)
response
.
Success
(
c
,
gen
)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func
(
h
*
SoraClientHandler
)
DeleteGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if
gen
.
StorageType
==
service
.
SoraStorageTypeLocal
&&
h
.
mediaStorage
!=
nil
{
paths
:=
gen
.
MediaURLs
if
len
(
paths
)
==
0
&&
gen
.
MediaURL
!=
""
{
paths
=
[]
string
{
gen
.
MediaURL
}
}
if
err
:=
h
.
mediaStorage
.
DeleteByRelativePaths
(
paths
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 删除本地文件失败 id=%d err=%v"
,
id
,
err
)
}
}
if
err
:=
h
.
genService
.
Delete
(
c
.
Request
.
Context
(),
id
,
userID
);
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已删除"
})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func
(
h
*
SoraClientHandler
)
GetQuota
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
if
h
.
quotaService
==
nil
{
response
.
Success
(
c
,
service
.
QuotaInfo
{
QuotaSource
:
"unlimited"
,
Source
:
"unlimited"
})
return
}
quota
,
err
:=
h
.
quotaService
.
GetQuota
(
c
.
Request
.
Context
(),
userID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
quota
)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func
(
h
*
SoraClientHandler
)
CancelGeneration
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
// 权限校验
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
_
=
gen
if
err
:=
h
.
genService
.
MarkCancelled
(
c
.
Request
.
Context
(),
id
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrSoraGenerationNotActive
)
{
response
.
Error
(
c
,
http
.
StatusConflict
,
"任务已结束,无法取消"
)
return
}
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已取消"
})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func
(
h
*
SoraClientHandler
)
SaveToStorage
(
c
*
gin
.
Context
)
{
userID
:=
getUserIDFromContext
(
c
)
if
userID
==
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"未登录"
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"无效的 ID"
)
return
}
gen
,
err
:=
h
.
genService
.
GetByID
(
c
.
Request
.
Context
(),
id
,
userID
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusNotFound
,
err
.
Error
())
return
}
if
gen
.
StorageType
!=
service
.
SoraStorageTypeUpstream
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"仅 upstream 类型的记录可手动保存"
)
return
}
if
gen
.
MediaURL
==
""
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"媒体 URL 为空,可能已过期"
)
return
}
if
h
.
s3Storage
==
nil
||
!
h
.
s3Storage
.
Enabled
(
c
.
Request
.
Context
())
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"云存储未配置,请联系管理员"
)
return
}
sourceURLs
:=
gen
.
MediaURLs
if
len
(
sourceURLs
)
==
0
&&
gen
.
MediaURL
!=
""
{
sourceURLs
=
[]
string
{
gen
.
MediaURL
}
}
if
len
(
sourceURLs
)
==
0
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
"媒体 URL 为空,可能已过期"
)
return
}
uploadedKeys
:=
make
([]
string
,
0
,
len
(
sourceURLs
))
accessURLs
:=
make
([]
string
,
0
,
len
(
sourceURLs
))
var
totalSize
int64
for
_
,
sourceURL
:=
range
sourceURLs
{
objectKey
,
fileSize
,
uploadErr
:=
h
.
s3Storage
.
UploadFromURL
(
c
.
Request
.
Context
(),
userID
,
sourceURL
)
if
uploadErr
!=
nil
{
if
len
(
uploadedKeys
)
>
0
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
}
var
upstreamErr
*
service
.
UpstreamDownloadError
if
errors
.
As
(
uploadErr
,
&
upstreamErr
)
&&
(
upstreamErr
.
StatusCode
==
http
.
StatusForbidden
||
upstreamErr
.
StatusCode
==
http
.
StatusNotFound
)
{
response
.
Error
(
c
,
http
.
StatusGone
,
"媒体链接已过期,无法保存"
)
return
}
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"上传到 S3 失败: "
+
uploadErr
.
Error
())
return
}
accessURL
,
err
:=
h
.
s3Storage
.
GetAccessURL
(
c
.
Request
.
Context
(),
objectKey
)
if
err
!=
nil
{
uploadedKeys
=
append
(
uploadedKeys
,
objectKey
)
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"生成 S3 访问链接失败: "
+
err
.
Error
())
return
}
uploadedKeys
=
append
(
uploadedKeys
,
objectKey
)
accessURLs
=
append
(
accessURLs
,
accessURL
)
totalSize
+=
fileSize
}
usageAdded
:=
false
if
totalSize
>
0
&&
h
.
quotaService
!=
nil
{
if
err
:=
h
.
quotaService
.
AddUsage
(
c
.
Request
.
Context
(),
userID
,
totalSize
);
err
!=
nil
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
var
quotaErr
*
service
.
QuotaExceededError
if
errors
.
As
(
err
,
&
quotaErr
)
{
response
.
Error
(
c
,
http
.
StatusTooManyRequests
,
"存储配额已满,请删除不需要的作品释放空间"
)
return
}
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"配额更新失败: "
+
err
.
Error
())
return
}
usageAdded
=
true
}
if
err
:=
h
.
genService
.
UpdateStorageForCompleted
(
c
.
Request
.
Context
(),
id
,
accessURLs
[
0
],
accessURLs
,
service
.
SoraStorageTypeS3
,
uploadedKeys
,
totalSize
,
);
err
!=
nil
{
_
=
h
.
s3Storage
.
DeleteObjects
(
c
.
Request
.
Context
(),
uploadedKeys
)
if
usageAdded
&&
h
.
quotaService
!=
nil
{
_
=
h
.
quotaService
.
ReleaseUsage
(
c
.
Request
.
Context
(),
userID
,
totalSize
)
}
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"已保存到 S3"
,
"object_key"
:
uploadedKeys
[
0
],
"object_keys"
:
uploadedKeys
,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func
(
h
*
SoraClientHandler
)
GetStorageStatus
(
c
*
gin
.
Context
)
{
s3Enabled
:=
h
.
s3Storage
!=
nil
&&
h
.
s3Storage
.
Enabled
(
c
.
Request
.
Context
())
s3Healthy
:=
false
if
s3Enabled
{
s3Healthy
=
h
.
s3Storage
.
IsHealthy
(
c
.
Request
.
Context
())
}
localEnabled
:=
h
.
mediaStorage
!=
nil
&&
h
.
mediaStorage
.
Enabled
()
response
.
Success
(
c
,
gin
.
H
{
"s3_enabled"
:
s3Enabled
,
"s3_healthy"
:
s3Healthy
,
"local_enabled"
:
localEnabled
,
})
}
func
(
h
*
SoraClientHandler
)
cleanupStoredMedia
(
ctx
context
.
Context
,
storageType
string
,
s3Keys
[]
string
,
localPaths
[]
string
)
{
switch
storageType
{
case
service
.
SoraStorageTypeS3
:
if
h
.
s3Storage
!=
nil
&&
len
(
s3Keys
)
>
0
{
if
err
:=
h
.
s3Storage
.
DeleteObjects
(
ctx
,
s3Keys
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 清理 S3 文件失败 keys=%v err=%v"
,
s3Keys
,
err
)
}
}
case
service
.
SoraStorageTypeLocal
:
if
h
.
mediaStorage
!=
nil
&&
len
(
localPaths
)
>
0
{
if
err
:=
h
.
mediaStorage
.
DeleteByRelativePaths
(
localPaths
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 清理本地文件失败 paths=%v err=%v"
,
localPaths
,
err
)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func
getUserIDFromContext
(
c
*
gin
.
Context
)
int64
{
if
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
);
ok
&&
subject
.
UserID
>
0
{
return
subject
.
UserID
}
if
id
,
ok
:=
c
.
Get
(
"user_id"
);
ok
{
switch
v
:=
id
.
(
type
)
{
case
int64
:
return
v
case
float64
:
return
int64
(
v
)
case
string
:
n
,
_
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
return
n
}
}
// 尝试从 JWT claims 获取
if
id
,
ok
:=
c
.
Get
(
"userID"
);
ok
{
if
v
,
ok
:=
id
.
(
int64
);
ok
{
return
v
}
}
return
0
}
func
groupIDForLog
(
groupID
*
int64
)
int64
{
if
groupID
==
nil
{
return
0
}
return
*
groupID
}
func
trimForLog
(
raw
string
,
maxLen
int
)
string
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
maxLen
<=
0
||
len
(
trimmed
)
<=
maxLen
{
return
trimmed
}
return
trimmed
[
:
maxLen
]
+
"...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func
(
h
*
SoraClientHandler
)
GetModels
(
c
*
gin
.
Context
)
{
families
:=
h
.
getModelFamilies
(
c
.
Request
.
Context
())
response
.
Success
(
c
,
families
)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func
(
h
*
SoraClientHandler
)
getModelFamilies
(
ctx
context
.
Context
)
[]
service
.
SoraModelFamily
{
// 读锁检查缓存
h
.
modelCacheMu
.
RLock
()
ttl
:=
modelCacheTTL
if
!
h
.
modelCacheUpstream
{
ttl
=
modelCacheFailedTTL
}
if
h
.
cachedFamilies
!=
nil
&&
time
.
Since
(
h
.
modelCacheTime
)
<
ttl
{
families
:=
h
.
cachedFamilies
h
.
modelCacheMu
.
RUnlock
()
return
families
}
h
.
modelCacheMu
.
RUnlock
()
// 写锁更新缓存
h
.
modelCacheMu
.
Lock
()
defer
h
.
modelCacheMu
.
Unlock
()
// double-check
ttl
=
modelCacheTTL
if
!
h
.
modelCacheUpstream
{
ttl
=
modelCacheFailedTTL
}
if
h
.
cachedFamilies
!=
nil
&&
time
.
Since
(
h
.
modelCacheTime
)
<
ttl
{
return
h
.
cachedFamilies
}
// 尝试从上游获取
families
,
err
:=
h
.
fetchUpstreamModels
(
ctx
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 上游模型获取失败,使用本地配置: %v"
,
err
)
families
=
service
.
BuildSoraModelFamilies
()
h
.
cachedFamilies
=
families
h
.
modelCacheTime
=
time
.
Now
()
h
.
modelCacheUpstream
=
false
return
families
}
logger
.
LegacyPrintf
(
"handler.sora_client"
,
"[SoraClient] 从上游同步到 %d 个模型家族"
,
len
(
families
))
h
.
cachedFamilies
=
families
h
.
modelCacheTime
=
time
.
Now
()
h
.
modelCacheUpstream
=
true
return
families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func
(
h
*
SoraClientHandler
)
fetchUpstreamModels
(
ctx
context
.
Context
)
([]
service
.
SoraModelFamily
,
error
)
{
if
h
.
gatewayService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gatewayService 未初始化"
)
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformSora
)
// 选择一个 Sora 账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
ctx
,
nil
,
""
,
"sora2-landscape-10s"
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"选择 Sora 账号失败: %w"
,
err
)
}
// 仅支持 API Key 类型账号
if
account
.
Type
!=
service
.
AccountTypeAPIKey
{
return
nil
,
fmt
.
Errorf
(
"当前账号类型 %s 不支持模型同步"
,
account
.
Type
)
}
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"账号缺少 api_key"
)
}
baseURL
:=
account
.
GetBaseURL
()
if
baseURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"账号缺少 base_url"
)
}
// 构建上游模型列表请求
modelsURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
"/sora/v1/models"
reqCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
10
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
reqCtx
,
http
.
MethodGet
,
modelsURL
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
client
:=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求上游失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"上游返回状态码 %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
1
*
1024
*
1024
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
// 解析 OpenAI 格式的模型列表
var
modelsResp
struct
{
Data
[]
struct
{
ID
string
`json:"id"`
}
`json:"data"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
modelsResp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"解析响应失败: %w"
,
err
)
}
if
len
(
modelsResp
.
Data
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"上游返回空模型列表"
)
}
// 提取模型 ID
modelIDs
:=
make
([]
string
,
0
,
len
(
modelsResp
.
Data
))
for
_
,
m
:=
range
modelsResp
.
Data
{
modelIDs
=
append
(
modelIDs
,
m
.
ID
)
}
// 转换为模型家族
families
:=
service
.
BuildSoraModelFamiliesFromIDs
(
modelIDs
)
if
len
(
families
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"未能从上游模型列表中识别出有效的模型家族"
)
}
return
families
,
nil
}
backend/internal/handler/sora_client_handler_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
handler
import
(
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
init
()
{
gin
.
SetMode
(
gin
.
TestMode
)
}
// ==================== Stub: SoraGenerationRepository ====================
var
_
service
.
SoraGenerationRepository
=
(
*
stubSoraGenRepo
)(
nil
)
type
stubSoraGenRepo
struct
{
gens
map
[
int64
]
*
service
.
SoraGeneration
nextID
int64
createErr
error
getErr
error
updateErr
error
deleteErr
error
listErr
error
countErr
error
countValue
int64
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
updateCallCount
*
int32
updateFailAfterN
int32
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
getByIDCallCount
int32
getByIDOverrideAfterN
int32
// 0 = 不覆盖
getByIDOverrideStatus
string
}
func
newStubSoraGenRepo
()
*
stubSoraGenRepo
{
return
&
stubSoraGenRepo
{
gens
:
make
(
map
[
int64
]
*
service
.
SoraGeneration
),
nextID
:
1
}
}
func
(
r
*
stubSoraGenRepo
)
Create
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
if
r
.
createErr
!=
nil
{
return
r
.
createErr
}
gen
.
ID
=
r
.
nextID
r
.
nextID
++
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubSoraGenRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
SoraGeneration
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
gen
,
ok
:=
r
.
gens
[
id
]
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"not found"
)
}
// 条件性状态覆盖:模拟外部取消等场景
if
r
.
getByIDOverrideAfterN
>
0
{
n
:=
atomic
.
AddInt32
(
&
r
.
getByIDCallCount
,
1
)
if
n
>
r
.
getByIDOverrideAfterN
{
cp
:=
*
gen
cp
.
Status
=
r
.
getByIDOverrideStatus
return
&
cp
,
nil
}
}
return
gen
,
nil
}
func
(
r
*
stubSoraGenRepo
)
Update
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
// 条件性失败:前 N 次成功,之后失败
if
r
.
updateCallCount
!=
nil
{
n
:=
atomic
.
AddInt32
(
r
.
updateCallCount
,
1
)
if
n
>
r
.
updateFailAfterN
{
return
fmt
.
Errorf
(
"conditional update error (call #%d)"
,
n
)
}
}
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubSoraGenRepo
)
Delete
(
_
context
.
Context
,
id
int64
)
error
{
if
r
.
deleteErr
!=
nil
{
return
r
.
deleteErr
}
delete
(
r
.
gens
,
id
)
return
nil
}
func
(
r
*
stubSoraGenRepo
)
List
(
_
context
.
Context
,
params
service
.
SoraGenerationListParams
)
([]
*
service
.
SoraGeneration
,
int64
,
error
)
{
if
r
.
listErr
!=
nil
{
return
nil
,
0
,
r
.
listErr
}
var
result
[]
*
service
.
SoraGeneration
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
!=
params
.
UserID
{
continue
}
result
=
append
(
result
,
gen
)
}
return
result
,
int64
(
len
(
result
)),
nil
}
func
(
r
*
stubSoraGenRepo
)
CountByUserAndStatus
(
_
context
.
Context
,
_
int64
,
_
[]
string
)
(
int64
,
error
)
{
if
r
.
countErr
!=
nil
{
return
0
,
r
.
countErr
}
return
r
.
countValue
,
nil
}
// ==================== 辅助函数 ====================
func
newTestSoraClientHandler
(
repo
*
stubSoraGenRepo
)
*
SoraClientHandler
{
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
return
&
SoraClientHandler
{
genService
:
genService
}
}
func
makeGinContext
(
method
,
path
,
body
string
,
userID
int64
)
(
*
gin
.
Context
,
*
httptest
.
ResponseRecorder
)
{
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
if
body
!=
""
{
c
.
Request
=
httptest
.
NewRequest
(
method
,
path
,
strings
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
}
else
{
c
.
Request
=
httptest
.
NewRequest
(
method
,
path
,
nil
)
}
if
userID
>
0
{
c
.
Set
(
"user_id"
,
userID
)
}
return
c
,
rec
}
func
parseResponse
(
t
*
testing
.
T
,
rec
*
httptest
.
ResponseRecorder
)
map
[
string
]
any
{
t
.
Helper
()
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
return
resp
}
// ==================== 纯函数测试: buildAsyncRequestBody ====================
func
TestBuildAsyncRequestBody
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"sora2-landscape-10s"
,
"一只猫在跳舞"
,
""
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
parsed
[
"model"
])
require
.
Equal
(
t
,
false
,
parsed
[
"stream"
])
msgs
:=
parsed
[
"messages"
]
.
([]
any
)
require
.
Len
(
t
,
msgs
,
1
)
msg
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"user"
,
msg
[
"role"
])
require
.
Equal
(
t
,
"一只猫在跳舞"
,
msg
[
"content"
])
}
func
TestBuildAsyncRequestBody_EmptyPrompt
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"gpt-image"
,
""
,
""
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"gpt-image"
,
parsed
[
"model"
])
msgs
:=
parsed
[
"messages"
]
.
([]
any
)
msg
:=
msgs
[
0
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
""
,
msg
[
"content"
])
}
func
TestBuildAsyncRequestBody_WithImageInput
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"gpt-image"
,
"一只猫"
,
"https://example.com/ref.png"
,
1
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
"https://example.com/ref.png"
,
parsed
[
"image_input"
])
}
func
TestBuildAsyncRequestBody_WithVideoCount
(
t
*
testing
.
T
)
{
body
:=
buildAsyncRequestBody
(
"sora2-landscape-10s"
,
"一只猫在跳舞"
,
""
,
3
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
body
,
&
parsed
))
require
.
Equal
(
t
,
float64
(
3
),
parsed
[
"video_count"
])
}
func
TestNormalizeVideoCount
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
normalizeVideoCount
(
"video"
,
0
))
require
.
Equal
(
t
,
2
,
normalizeVideoCount
(
"video"
,
2
))
require
.
Equal
(
t
,
3
,
normalizeVideoCount
(
"video"
,
5
))
require
.
Equal
(
t
,
1
,
normalizeVideoCount
(
"image"
,
3
))
}
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
func
TestParseMediaURLsFromBody_MediaURLs
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_SingleMediaURL
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_url":"https://a.com/video.mp4"}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/video.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_EmptyBody
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
(
nil
))
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
{}))
}
func
TestParseMediaURLsFromBody_InvalidJSON
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
"not json"
)))
}
func
TestParseMediaURLsFromBody_NoMediaFields
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"data":"something"}`
)))
}
func
TestParseMediaURLsFromBody_EmptyMediaURL
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_url":""}`
)))
}
func
TestParseMediaURLsFromBody_EmptyMediaURLs
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":[]}`
)))
}
func
TestParseMediaURLsFromBody_MediaURLsPriority
(
t
*
testing
.
T
)
{
body
:=
`{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
urls
:=
parseMediaURLsFromBody
([]
byte
(
body
))
require
.
Len
(
t
,
urls
,
2
)
require
.
Equal
(
t
,
"https://multi.com/a.mp4"
,
urls
[
0
])
}
func
TestParseMediaURLsFromBody_FilterEmpty
(
t
*
testing
.
T
)
{
urls
:=
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`
))
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
}
func
TestParseMediaURLsFromBody_AllEmpty
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":["",""]}`
)))
}
func
TestParseMediaURLsFromBody_NonStringArray
(
t
*
testing
.
T
)
{
// media_urls 不是 string 数组
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_urls":"not-array"}`
)))
}
func
TestParseMediaURLsFromBody_MediaURLNotString
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
parseMediaURLsFromBody
([]
byte
(
`{"media_url":123}`
)))
}
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
func
TestExtractMediaURLsFromResult_OAuthPath
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
"https://oauth.com/video.mp4"
}
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Equal
(
t
,
"https://oauth.com/video.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://oauth.com/video.mp4"
},
urls
)
}
func
TestExtractMediaURLsFromResult_OAuthWithBody
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
"https://oauth.com/video.mp4"
}
recorder
:=
httptest
.
NewRecorder
()
_
,
_
=
recorder
.
Write
([]
byte
(
`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`
))
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Equal
(
t
,
"https://body.com/1.mp4"
,
url
)
require
.
Len
(
t
,
urls
,
2
)
}
func
TestExtractMediaURLsFromResult_APIKeyPath
(
t
*
testing
.
T
)
{
recorder
:=
httptest
.
NewRecorder
()
_
,
_
=
recorder
.
Write
([]
byte
(
`{"media_url":"https://upstream.com/video.mp4"}`
))
url
,
urls
:=
extractMediaURLsFromResult
(
nil
,
recorder
)
require
.
Equal
(
t
,
"https://upstream.com/video.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://upstream.com/video.mp4"
},
urls
)
}
func
TestExtractMediaURLsFromResult_NilResultEmptyBody
(
t
*
testing
.
T
)
{
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
nil
,
recorder
)
require
.
Empty
(
t
,
url
)
require
.
Nil
(
t
,
urls
)
}
func
TestExtractMediaURLsFromResult_EmptyMediaURL
(
t
*
testing
.
T
)
{
result
:=
&
service
.
ForwardResult
{
MediaURL
:
""
}
recorder
:=
httptest
.
NewRecorder
()
url
,
urls
:=
extractMediaURLsFromResult
(
result
,
recorder
)
require
.
Empty
(
t
,
url
)
require
.
Nil
(
t
,
urls
)
}
// ==================== getUserIDFromContext ====================
func
TestGetUserIDFromContext_Int64
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
int64
(
42
))
require
.
Equal
(
t
,
int64
(
42
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_AuthSubject
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
777
})
require
.
Equal
(
t
,
int64
(
777
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_Float64
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
float64
(
99
))
require
.
Equal
(
t
,
int64
(
99
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_String
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
"123"
)
require
.
Equal
(
t
,
int64
(
123
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_UserIDFallback
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"userID"
,
int64
(
55
))
require
.
Equal
(
t
,
int64
(
55
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_NoID
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
require
.
Equal
(
t
,
int64
(
0
),
getUserIDFromContext
(
c
))
}
func
TestGetUserIDFromContext_InvalidString
(
t
*
testing
.
T
)
{
c
,
_
:=
gin
.
CreateTestContext
(
httptest
.
NewRecorder
())
c
.
Request
=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
c
.
Set
(
"user_id"
,
"not-a-number"
)
require
.
Equal
(
t
,
int64
(
0
),
getUserIDFromContext
(
c
))
}
// ==================== Handler: Generate ====================
func
TestGenerate_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
0
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_MissingModel
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_MissingPrompt
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_BadRequest_InvalidJSON
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{invalid`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGenerate_TooManyRequests
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
countValue
=
3
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
func
TestGenerate_CountError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
countErr
=
fmt
.
Errorf
(
"db error"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestGenerate_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"测试生成"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
NotZero
(
t
,
data
[
"generation_id"
])
require
.
Equal
(
t
,
"pending"
,
data
[
"status"
])
}
func
TestGenerate_DefaultMediaType
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"video"
,
repo
.
gens
[
1
]
.
MediaType
)
}
func
TestGenerate_ImageMediaType
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"gpt-image","prompt":"test","media_type":"image"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"image"
,
repo
.
gens
[
1
]
.
MediaType
)
}
func
TestGenerate_CreatePendingError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
createErr
=
fmt
.
Errorf
(
"create failed"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestGenerate_NilQuotaServiceSkipsCheck
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestGenerate_APIKeyInContext
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
42
))
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_NoAPIKeyInContext
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_ConcurrencyBoundary
(
t
*
testing
.
T
)
{
// activeCount == 2 应该允许
repo
:=
newStubSoraGenRepo
()
repo
.
countValue
=
2
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== Handler: ListGenerations ====================
func
TestListGenerations_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
0
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestListGenerations_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Model
:
"sora2-landscape-10s"
,
Status
:
"completed"
,
StorageType
:
"upstream"
}
repo
.
gens
[
2
]
=
&
service
.
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Model
:
"gpt-image"
,
Status
:
"pending"
,
StorageType
:
"none"
}
repo
.
nextID
=
3
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations?page=1&page_size=10"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
items
:=
data
[
"data"
]
.
([]
any
)
require
.
Len
(
t
,
items
,
2
)
require
.
Equal
(
t
,
float64
(
2
),
data
[
"total"
])
}
func
TestListGenerations_ListError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
listErr
=
fmt
.
Errorf
(
"db error"
)
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
func
TestListGenerations_DefaultPagination
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
h
:=
newTestSoraClientHandler
(
repo
)
// 不传分页参数,应默认 page=1 page_size=20
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations"
,
""
,
1
)
h
.
ListGenerations
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
float64
(
1
),
data
[
"page"
])
}
// ==================== Handler: GetGeneration ====================
func
TestGetGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGetGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/abc"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestGetGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/999"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestGetGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestGetGeneration_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Model
:
"sora2-landscape-10s"
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
GetGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
float64
(
1
),
data
[
"id"
])
}
// ==================== Handler: DeleteGeneration ====================
func
TestDeleteGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestDeleteGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/abc"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestDeleteGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/999"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestDeleteGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestDeleteGeneration_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
// ==================== Handler: CancelGeneration ====================
func
TestCancelGeneration_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestCancelGeneration_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/abc/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestCancelGeneration_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/999/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestCancelGeneration_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"pending"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestCancelGeneration_Pending
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestCancelGeneration_Generating
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"generating"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestCancelGeneration_Completed
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
func
TestCancelGeneration_Failed
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"failed"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
func
TestCancelGeneration_Cancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"cancelled"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
// ==================== Handler: GetQuota ====================
func
TestGetQuota_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
0
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestGetQuota_NilQuotaService
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
1
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"unlimited"
,
data
[
"source"
])
}
// ==================== Handler: GetModels ====================
func
TestGetModels
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/models"
,
""
,
0
)
h
.
GetModels
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
([]
any
)
require
.
Len
(
t
,
data
,
4
)
// 验证类型分布
videoCount
,
imageCount
:=
0
,
0
for
_
,
item
:=
range
data
{
m
:=
item
.
(
map
[
string
]
any
)
if
m
[
"type"
]
==
"video"
{
videoCount
++
}
else
if
m
[
"type"
]
==
"image"
{
imageCount
++
}
}
require
.
Equal
(
t
,
3
,
videoCount
)
require
.
Equal
(
t
,
1
,
imageCount
)
}
// ==================== Handler: GetStorageStatus ====================
func
TestGetStorageStatus_NilS3
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
false
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
require
.
Equal
(
t
,
false
,
data
[
"local_enabled"
])
}
func
TestGetStorageStatus_LocalEnabled
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-storage-status-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
false
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
require
.
Equal
(
t
,
true
,
data
[
"local_enabled"
])
}
// ==================== Handler: SaveToStorage ====================
func
TestSaveToStorage_Unauthorized
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
0
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
rec
.
Code
)
}
func
TestSaveToStorage_InvalidID
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/abc/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"abc"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_NotFound
(
t
*
testing
.
T
)
{
h
:=
newTestSoraClientHandler
(
newStubSoraGenRepo
())
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/999/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"999"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestSaveToStorage_NotUpstream
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"s3"
,
MediaURL
:
"https://example.com/v.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_EmptyMediaURL
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
""
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestSaveToStorage_S3Nil
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"云存储"
)
}
func
TestSaveToStorage_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
"https://example.com/video.mp4"
}
h
:=
newTestSoraClientHandler
(
repo
)
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
func
TestStoreMediaWithDegradation_NilS3NilMedia
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
urls
,
storageType
,
keys
,
size
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://upstream.com/v.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://upstream.com/v.mp4"
},
urls
)
require
.
Nil
(
t
,
keys
)
require
.
Equal
(
t
,
int64
(
0
),
size
)
}
func
TestStoreMediaWithDegradation_NilGuardsMultiURL
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
urls
,
storageType
,
keys
,
size
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://a.com/1.mp4"
,
url
)
require
.
Equal
(
t
,
[]
string
{
"https://a.com/1.mp4"
,
"https://a.com/2.mp4"
},
urls
)
require
.
Nil
(
t
,
keys
)
require
.
Equal
(
t
,
int64
(
0
),
size
)
}
func
TestStoreMediaWithDegradation_EmptyMediaURLsFallback
(
t
*
testing
.
T
)
{
h
:=
&
SoraClientHandler
{}
url
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
[]
string
{},
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Equal
(
t
,
"https://upstream.com/v.mp4"
,
url
)
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var
_
service
.
UserRepository
=
(
*
stubUserRepoForHandler
)(
nil
)
type
stubUserRepoForHandler
struct
{
users
map
[
int64
]
*
service
.
User
updateErr
error
}
func
newStubUserRepoForHandler
()
*
stubUserRepoForHandler
{
return
&
stubUserRepoForHandler
{
users
:
make
(
map
[
int64
]
*
service
.
User
)}
}
func
(
r
*
stubUserRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
u
,
ok
:=
r
.
users
[
id
];
ok
{
return
u
,
nil
}
return
nil
,
fmt
.
Errorf
(
"user not found"
)
}
func
(
r
*
stubUserRepoForHandler
)
Update
(
_
context
.
Context
,
user
*
service
.
User
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
users
[
user
.
ID
]
=
user
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
User
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
GetByEmail
(
context
.
Context
,
string
)
(
*
service
.
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
GetFirstAdmin
(
context
.
Context
)
(
*
service
.
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubUserRepoForHandler
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForHandler
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// ==================== NewSoraClientHandler ====================
func
TestNewSoraClientHandler
(
t
*
testing
.
T
)
{
h
:=
NewSoraClientHandler
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
require
.
NotNil
(
t
,
h
)
}
func
TestNewSoraClientHandler_WithAPIKeyService
(
t
*
testing
.
T
)
{
h
:=
NewSoraClientHandler
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
require
.
NotNil
(
t
,
h
)
require
.
Nil
(
t
,
h
.
apiKeyService
)
}
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
var
_
service
.
APIKeyRepository
=
(
*
stubAPIKeyRepoForHandler
)(
nil
)
type
stubAPIKeyRepoForHandler
struct
{
keys
map
[
int64
]
*
service
.
APIKey
getErr
error
}
func
newStubAPIKeyRepoForHandler
()
*
stubAPIKeyRepoForHandler
{
return
&
stubAPIKeyRepoForHandler
{
keys
:
make
(
map
[
int64
]
*
service
.
APIKey
)}
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
APIKey
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
if
k
,
ok
:=
r
.
keys
[
id
];
ok
{
return
k
,
nil
}
return
nil
,
fmt
.
Errorf
(
"api key not found: %d"
,
id
)
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
APIKey
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetKeyAndOwnerID
(
_
context
.
Context
,
_
int64
)
(
string
,
int64
,
error
)
{
return
""
,
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByKey
(
context
.
Context
,
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetByKeyForAuth
(
context
.
Context
,
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Update
(
context
.
Context
,
*
service
.
APIKey
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListByUserID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
,
_
service
.
APIKeyListFilters
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
VerifyOwnership
(
context
.
Context
,
int64
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
CountByUserID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ExistsByKey
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListByGroupID
(
_
context
.
Context
,
_
int64
,
_
pagination
.
PaginationParams
)
([]
service
.
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
SearchAPIKeys
(
context
.
Context
,
int64
,
string
,
int
)
([]
service
.
APIKey
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
UpdateGroupIDByUserAndGroup
(
_
context
.
Context
,
userID
,
oldGroupID
,
newGroupID
int64
)
(
int64
,
error
)
{
var
updated
int64
for
id
,
key
:=
range
r
.
keys
{
if
key
.
UserID
!=
userID
||
key
.
GroupID
==
nil
||
*
key
.
GroupID
!=
oldGroupID
{
continue
}
clone
:=
*
key
gid
:=
newGroupID
clone
.
GroupID
=
&
gid
r
.
keys
[
id
]
=
&
clone
updated
++
}
return
updated
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListKeysByUserID
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ListKeysByGroupID
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
IncrementQuotaUsed
(
_
context
.
Context
,
_
int64
,
_
float64
)
(
float64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
UpdateLastUsed
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
IncrementRateLimitUsage
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
ResetRateLimitWindows
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAPIKeyRepoForHandler
)
GetRateLimitData
(
context
.
Context
,
int64
)
(
*
service
.
APIKeyRateLimitData
,
error
)
{
return
nil
,
nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService
func
newTestAPIKeyService
(
repo
*
stubAPIKeyRepoForHandler
)
*
service
.
APIKeyService
{
return
service
.
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{})
}
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
func
TestGenerate_WithAPIKeyID_Success
(
t
*
testing
.
T
)
{
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
groupID
:=
int64
(
5
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
&
groupID
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
NotZero
(
t
,
data
[
"generation_id"
])
// 验证 api_key_id 已关联到生成记录
gen
:=
repo
.
gens
[
1
]
require
.
NotNil
(
t
,
gen
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
gen
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyID_NotFound
(
t
*
testing
.
T
)
{
// 前端传递不存在的 api_key_id → 400
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不存在"
)
}
func
TestGenerate_WithAPIKeyID_WrongUser
(
t
*
testing
.
T
)
{
// 前端传递别人的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
999
,
// 属于 user 999
Status
:
service
.
StatusAPIKeyActive
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不属于"
)
}
func
TestGenerate_WithAPIKeyID_Disabled
(
t
*
testing
.
T
)
{
// 前端传递已禁用的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyDisabled
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"不可用"
)
}
func
TestGenerate_WithAPIKeyID_QuotaExhausted
(
t
*
testing
.
T
)
{
// 前端传递配额耗尽的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyQuotaExhausted
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
func
TestGenerate_WithAPIKeyID_Expired
(
t
*
testing
.
T
)
{
// 前端传递已过期的 api_key_id → 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyExpired
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
func
TestGenerate_WithAPIKeyID_NilAPIKeyService
(
t
*
testing
.
T
)
{
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// apiKeyService = nil
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyID_NilGroupID
(
t
*
testing
.
T
)
{
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
nil
,
// 无分组
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_NoAPIKeyID_NoContext_NilResult
(
t
*
testing
.
T
)
{
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Nil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithAPIKeyIDInBody_OverridesContext
(
t
*
testing
.
T
)
{
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
groupID
:=
int64
(
10
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyRepo
.
keys
[
42
]
=
&
service
.
APIKey
{
ID
:
42
,
UserID
:
1
,
Status
:
service
.
StatusAPIKeyActive
,
GroupID
:
&
groupID
,
}
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
99
))
// context 中有另一个 api_key_id
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_WithContextAPIKeyID_FallbackPath
(
t
*
testing
.
T
)
{
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
c
.
Set
(
"api_key_id"
,
int64
(
99
))
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 应使用 context 中的 api_key_id=99
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
99
),
*
repo
.
gens
[
1
]
.
APIKeyID
)
}
func
TestGenerate_APIKeyID_Zero_IgnoredInJSON
(
t
*
testing
.
T
)
{
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyRepo
:=
newStubAPIKeyRepoForHandler
()
apiKeyService
:=
newTestAPIKeyService
(
apiKeyRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
apiKeyService
:
apiKeyService
}
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
// api_key_id=0 不存在 → 400
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
func
TestProcessGeneration_WithGroupID_NoForcePlatform
(
t
*
testing
.
T
)
{
// groupID 不为 nil → 不设置 ForcePlatform
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
gid
:=
int64
(
5
)
h
.
processGeneration
(
1
,
1
,
&
gid
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
func
TestProcessGeneration_NilGroupID_SetsForcePlatform
(
t
*
testing
.
T
)
{
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
func
TestProcessGeneration_MarkGeneratingStateConflict
(
t
*
testing
.
T
)
{
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"cancelled"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
require
.
Equal
(
t
,
"cancelled"
,
repo
.
gens
[
1
]
.
Status
)
}
// ==================== GenerateRequest JSON 解析 ====================
func
TestGenerateRequest_WithAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// 验证 api_key_id 在 JSON 中正确解析为 *int64
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test","api_key_id":42}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
req
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
req
.
APIKeyID
)
}
func
TestGenerateRequest_WithoutAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// 不传 api_key_id → 解析后为 nil
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test"}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
req
.
APIKeyID
)
}
func
TestGenerateRequest_NullAPIKeyID_JSONParsing
(
t
*
testing
.
T
)
{
// api_key_id: null → 解析后为 nil
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{"model":"sora2","prompt":"test","api_key_id":null}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
req
.
APIKeyID
)
}
func
TestGenerateRequest_FullFields_JSONParsing
(
t
*
testing
.
T
)
{
// 全字段解析
var
req
GenerateRequest
err
:=
json
.
Unmarshal
([]
byte
(
`{
"model":"sora2-landscape-10s",
"prompt":"test prompt",
"media_type":"video",
"video_count":2,
"image_input":"data:image/png;base64,abc",
"api_key_id":100
}`
),
&
req
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
req
.
Model
)
require
.
Equal
(
t
,
"test prompt"
,
req
.
Prompt
)
require
.
Equal
(
t
,
"video"
,
req
.
MediaType
)
require
.
Equal
(
t
,
2
,
req
.
VideoCount
)
require
.
Equal
(
t
,
"data:image/png;base64,abc"
,
req
.
ImageInput
)
require
.
NotNil
(
t
,
req
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
100
),
*
req
.
APIKeyID
)
}
func
TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID
(
t
*
testing
.
T
)
{
// api_key_id 为 nil 时 JSON 序列化应省略
req
:=
GenerateRequest
{
Model
:
"sora2"
,
Prompt
:
"test"
}
b
,
err
:=
json
.
Marshal
(
req
)
require
.
NoError
(
t
,
err
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
b
,
&
parsed
))
_
,
hasAPIKeyID
:=
parsed
[
"api_key_id"
]
require
.
False
(
t
,
hasAPIKeyID
,
"api_key_id 为 nil 时应省略"
)
}
func
TestGenerateRequest_JSONSerialize_IncludesAPIKeyID
(
t
*
testing
.
T
)
{
// api_key_id 不为 nil 时 JSON 序列化应包含
id
:=
int64
(
42
)
req
:=
GenerateRequest
{
Model
:
"sora2"
,
Prompt
:
"test"
,
APIKeyID
:
&
id
}
b
,
err
:=
json
.
Marshal
(
req
)
require
.
NoError
(
t
,
err
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
b
,
&
parsed
))
require
.
Equal
(
t
,
float64
(
42
),
parsed
[
"api_key_id"
])
}
// ==================== GetQuota: 有配额服务 ====================
func
TestGetQuota_WithQuotaService_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
1
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"user"
,
data
[
"source"
])
require
.
Equal
(
t
,
float64
(
10
*
1024
*
1024
),
data
[
"quota_bytes"
])
require
.
Equal
(
t
,
float64
(
3
*
1024
*
1024
),
data
[
"used_bytes"
])
}
func
TestGetQuota_WithQuotaService_Error
(
t
*
testing
.
T
)
{
// 用户不存在时 GetQuota 返回错误
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/quota"
,
""
,
999
)
h
.
GetQuota
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== Generate: 配额检查 ====================
func
TestGenerate_QuotaCheckFailed
(
t
*
testing
.
T
)
{
// 配额超限时返回 429
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
1024
,
SoraStorageUsedBytes
:
1025
,
// 已超限
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
func
TestGenerate_QuotaCheckPassed
(
t
*
testing
.
T
)
{
// 配额充足时允许生成
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
quotaService
:
quotaService
,
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
`{"model":"sora2-landscape-10s","prompt":"test"}`
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
var
_
service
.
SettingRepository
=
(
*
stubSettingRepoForHandler
)(
nil
)
type
stubSettingRepoForHandler
struct
{
values
map
[
string
]
string
}
func
newStubSettingRepoForHandler
(
values
map
[
string
]
string
)
*
stubSettingRepoForHandler
{
if
values
==
nil
{
values
=
make
(
map
[
string
]
string
)
}
return
&
stubSettingRepoForHandler
{
values
:
values
}
}
func
(
r
*
stubSettingRepoForHandler
)
Get
(
_
context
.
Context
,
key
string
)
(
*
service
.
Setting
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
&
service
.
Setting
{
Key
:
key
,
Value
:
v
},
nil
}
return
nil
,
service
.
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForHandler
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
v
,
nil
}
return
""
,
service
.
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForHandler
)
Set
(
_
context
.
Context
,
key
,
value
string
)
error
{
r
.
values
[
key
]
=
value
return
nil
}
func
(
r
*
stubSettingRepoForHandler
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
result
:=
make
(
map
[
string
]
string
)
for
_
,
k
:=
range
keys
{
if
v
,
ok
:=
r
.
values
[
k
];
ok
{
result
[
k
]
=
v
}
}
return
result
,
nil
}
func
(
r
*
stubSettingRepoForHandler
)
SetMultiple
(
_
context
.
Context
,
settings
map
[
string
]
string
)
error
{
for
k
,
v
:=
range
settings
{
r
.
values
[
k
]
=
v
}
return
nil
}
func
(
r
*
stubSettingRepoForHandler
)
GetAll
(
_
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
r
.
values
,
nil
}
func
(
r
*
stubSettingRepoForHandler
)
Delete
(
_
context
.
Context
,
key
string
)
error
{
delete
(
r
.
values
,
key
)
return
nil
}
// ==================== S3 / MediaStorage 辅助函数 ====================
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
func
newS3StorageForHandler
(
endpoint
string
)
*
service
.
SoraS3Storage
{
settingRepo
:=
newStubSettingRepoForHandler
(
map
[
string
]
string
{
"sora_s3_enabled"
:
"true"
,
"sora_s3_endpoint"
:
endpoint
,
"sora_s3_region"
:
"us-east-1"
,
"sora_s3_bucket"
:
"test-bucket"
,
"sora_s3_access_key_id"
:
"AKIATEST"
,
"sora_s3_secret_access_key"
:
"test-secret"
,
"sora_s3_prefix"
:
"sora"
,
"sora_s3_force_path_style"
:
"true"
,
})
settingService
:=
service
.
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
return
service
.
NewSoraS3Storage
(
settingService
)
}
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
func
newFakeSourceServer
()
*
httptest
.
Server
{
return
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"video/mp4"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"fake video data for test"
))
}))
}
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
func
newFakeS3Server
(
mode
string
)
*
httptest
.
Server
{
var
counter
atomic
.
Int32
return
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
Copy
(
io
.
Discard
,
r
.
Body
)
_
=
r
.
Body
.
Close
()
switch
mode
{
case
"ok"
:
w
.
Header
()
.
Set
(
"ETag"
,
`"test-etag"`
)
w
.
WriteHeader
(
http
.
StatusOK
)
case
"fail"
:
w
.
WriteHeader
(
http
.
StatusForbidden
)
_
,
_
=
w
.
Write
([]
byte
(
`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`
))
case
"fail-second"
:
n
:=
counter
.
Add
(
1
)
if
n
<=
1
{
w
.
Header
()
.
Set
(
"ETag"
,
`"test-etag"`
)
w
.
WriteHeader
(
http
.
StatusOK
)
}
else
{
w
.
WriteHeader
(
http
.
StatusForbidden
)
_
,
_
=
w
.
Write
([]
byte
(
`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`
))
}
}
}))
}
// ==================== processGeneration 直接调用测试 ====================
func
TestProcessGeneration_MarkGeneratingFails
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Empty
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
)
}
func
TestProcessGeneration_GatewayServiceNil
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
// gatewayService 未设置 → MarkFailed
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"gatewayService"
)
}
// ==================== storeMediaWithDegradation: S3 路径 ====================
func
TestStoreMediaWithDegradation_S3SuccessSingleURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
storageType
)
require
.
Len
(
t
,
s3Keys
,
1
)
require
.
NotEmpty
(
t
,
s3Keys
[
0
])
require
.
Len
(
t
,
storedURLs
,
1
)
require
.
Equal
(
t
,
storedURL
,
storedURLs
[
0
])
require
.
Contains
(
t
,
storedURL
,
fakeS3
.
URL
)
require
.
Contains
(
t
,
storedURL
,
"/test-bucket/"
)
require
.
Greater
(
t
,
fileSize
,
int64
(
0
))
}
func
TestStoreMediaWithDegradation_S3SuccessMultiURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
urls
:=
[]
string
{
sourceServer
.
URL
+
"/a.mp4"
,
sourceServer
.
URL
+
"/b.mp4"
}
storedURL
,
storedURLs
,
storageType
,
s3Keys
,
fileSize
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/a.mp4"
,
urls
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
storageType
)
require
.
Len
(
t
,
s3Keys
,
2
)
require
.
Len
(
t
,
storedURLs
,
2
)
require
.
Equal
(
t
,
storedURL
,
storedURLs
[
0
])
require
.
Contains
(
t
,
storedURLs
[
0
],
fakeS3
.
URL
)
require
.
Contains
(
t
,
storedURLs
[
1
],
fakeS3
.
URL
)
require
.
Greater
(
t
,
fileSize
,
int64
(
0
))
}
func
TestStoreMediaWithDegradation_S3DownloadFails
(
t
*
testing
.
T
)
{
// 上游返回 404 → 下载失败 → S3 上传不会开始
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
badSource
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusNotFound
)
}))
defer
badSource
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
badSource
.
URL
+
"/missing.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
}
func
TestStoreMediaWithDegradation_S3FailsSingleURL
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
// S3 失败,降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
}
func
TestStoreMediaWithDegradation_S3PartialFailureCleanup
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail-second"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
urls
:=
[]
string
{
sourceServer
.
URL
+
"/a.mp4"
,
sourceServer
.
URL
+
"/b.mp4"
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/a.mp4"
,
urls
,
)
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
}
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
func
TestStoreMediaWithDegradation_LocalStorageFails
(
t
*
testing
.
T
)
{
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
"/dev/null/invalid_dir"
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
"https://upstream.com/v.mp4"
,
nil
,
)
// 本地存储失败,降级到 upstream
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
storageType
)
}
func
TestStoreMediaWithDegradation_LocalStorageSuccess
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-handler-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
DownloadTimeoutSeconds
:
5
,
MaxDownloadBytes
:
10
*
1024
*
1024
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
_
,
_
,
storageType
,
s3Keys
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeLocal
,
storageType
)
require
.
Nil
(
t
,
s3Keys
)
// 本地存储不返回 S3 keys
}
func
TestStoreMediaWithDegradation_S3FailsFallbackToLocal
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-handler-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
DownloadTimeoutSeconds
:
5
,
MaxDownloadBytes
:
10
*
1024
*
1024
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
,
mediaStorage
:
mediaStorage
,
}
_
,
_
,
storageType
,
_
,
_
:=
h
.
storeMediaWithDegradation
(
context
.
Background
(),
1
,
"video"
,
sourceServer
.
URL
+
"/v.mp4"
,
nil
,
)
// S3 失败 → 本地存储成功
require
.
Equal
(
t
,
service
.
SoraStorageTypeLocal
,
storageType
)
}
// ==================== SaveToStorage: S3 路径 ====================
func
TestSaveToStorage_S3EnabledButUploadFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"S3"
)
}
func
TestSaveToStorage_UpstreamURLExpired
(
t
*
testing
.
T
)
{
expiredServer
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusForbidden
)
}))
defer
expiredServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
expiredServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusGone
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
fmt
.
Sprint
(
resp
[
"message"
]),
"过期"
)
}
func
TestSaveToStorage_S3EnabledUploadSuccess
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Contains
(
t
,
data
[
"message"
],
"S3"
)
require
.
NotEmpty
(
t
,
data
[
"object_key"
])
// 验证记录已更新为 S3 存储
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
}
func
TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v1.mp4"
,
MediaURLs
:
[]
string
{
sourceServer
.
URL
+
"/v1.mp4"
,
sourceServer
.
URL
+
"/v2.mp4"
,
},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Len
(
t
,
data
[
"object_keys"
]
.
([]
any
),
2
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
Len
(
t
,
repo
.
gens
[
1
]
.
S3ObjectKeys
,
2
)
require
.
Len
(
t
,
repo
.
gens
[
1
]
.
MediaURLs
,
2
)
}
func
TestSaveToStorage_S3EnabledUploadSuccessWithQuota
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
// 验证配额已累加
require
.
Greater
(
t
,
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
,
int64
(
0
))
}
func
TestSaveToStorage_S3UploadSuccessMarkCompletedFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== GetStorageStatus: S3 路径 ====================
func
TestGetStorageStatus_S3EnabledNotHealthy
(
t
*
testing
.
T
)
{
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
true
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
false
,
data
[
"s3_healthy"
])
}
func
TestGetStorageStatus_S3EnabledHealthy
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"GET"
,
"/api/v1/sora/storage-status"
,
""
,
0
)
h
.
GetStorageStatus
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
true
,
data
[
"s3_enabled"
])
require
.
Equal
(
t
,
true
,
data
[
"s3_healthy"
])
}
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
var
_
service
.
AccountRepository
=
(
*
stubAccountRepoForHandler
)(
nil
)
type
stubAccountRepoForHandler
struct
{
accounts
[]
service
.
Account
}
func
(
r
*
stubAccountRepoForHandler
)
Create
(
context
.
Context
,
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
for
i
:=
range
r
.
accounts
{
if
r
.
accounts
[
i
]
.
ID
==
id
{
return
&
r
.
accounts
[
i
],
nil
}
}
return
nil
,
fmt
.
Errorf
(
"account not found"
)
}
func
(
r
*
stubAccountRepoForHandler
)
GetByIDs
(
context
.
Context
,
[]
int64
)
([]
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ExistsByID
(
context
.
Context
,
int64
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
GetByCRSAccountID
(
context
.
Context
,
string
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
FindByExtraField
(
context
.
Context
,
string
,
any
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListCRSAccountIDs
(
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
Update
(
context
.
Context
,
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
string
,
int64
,
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListByGroup
(
context
.
Context
,
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListActive
(
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListByPlatform
(
context
.
Context
,
string
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateLastUsed
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BatchUpdateLastUsed
(
context
.
Context
,
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetError
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearError
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetSchedulable
(
context
.
Context
,
int64
,
bool
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
AutoPauseExpiredAccounts
(
context
.
Context
,
time
.
Time
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BindGroups
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulable
(
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupID
(
context
.
Context
,
int64
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByPlatform
(
_
context
.
Context
,
_
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupIDAndPlatform
(
context
.
Context
,
int64
,
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByPlatforms
(
context
.
Context
,
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableByGroupIDAndPlatforms
(
context
.
Context
,
int64
,
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableUngroupedByPlatform
(
_
context
.
Context
,
_
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ListSchedulableUngroupedByPlatforms
(
_
context
.
Context
,
_
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
accounts
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetRateLimited
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetModelRateLimit
(
context
.
Context
,
int64
,
string
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetOverloaded
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
SetTempUnschedulable
(
context
.
Context
,
int64
,
time
.
Time
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearTempUnschedulable
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearRateLimit
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearAntigravityQuotaScopes
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ClearModelRateLimits
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateSessionWindow
(
context
.
Context
,
int64
,
*
time
.
Time
,
*
time
.
Time
,
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
UpdateExtra
(
context
.
Context
,
int64
,
map
[
string
]
any
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
BulkUpdate
(
context
.
Context
,
[]
int64
,
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepoForHandler
)
IncrementQuotaUsed
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepoForHandler
)
ResetQuotaUsed
(
context
.
Context
,
int64
)
error
{
return
nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var
_
service
.
SoraClient
=
(
*
stubSoraClientForHandler
)(
nil
)
type
stubSoraClientForHandler
struct
{
videoStatus
*
service
.
SoraVideoTaskStatus
}
func
(
s
*
stubSoraClientForHandler
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClientForHandler
)
UploadImage
(
context
.
Context
,
*
service
.
Account
,
[]
byte
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateImageTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateVideoTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraVideoRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
CreateStoryboardTask
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraStoryboardRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
UploadCharacterVideo
(
context
.
Context
,
*
service
.
Account
,
[]
byte
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetCameoStatus
(
context
.
Context
,
*
service
.
Account
,
string
)
(
*
service
.
SoraCameoStatus
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
DownloadCharacterImage
(
context
.
Context
,
*
service
.
Account
,
string
)
([]
byte
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
UploadCharacterImage
(
context
.
Context
,
*
service
.
Account
,
[]
byte
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
FinalizeCharacter
(
context
.
Context
,
*
service
.
Account
,
service
.
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
SetCharacterPublic
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
DeleteCharacter
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
PostVideoForWatermarkFree
(
context
.
Context
,
*
service
.
Account
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
DeletePost
(
context
.
Context
,
*
service
.
Account
,
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetWatermarkFreeURLCustom
(
context
.
Context
,
*
service
.
Account
,
string
,
string
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
EnhancePrompt
(
context
.
Context
,
*
service
.
Account
,
string
,
string
,
int
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetImageTask
(
context
.
Context
,
*
service
.
Account
,
string
)
(
*
service
.
SoraImageTaskStatus
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubSoraClientForHandler
)
GetVideoTask
(
_
context
.
Context
,
_
*
service
.
Account
,
_
string
)
(
*
service
.
SoraVideoTaskStatus
,
error
)
{
return
s
.
videoStatus
,
nil
}
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
func
newMinimalGatewayService
(
accountRepo
service
.
AccountRepository
)
*
service
.
GatewayService
{
return
service
.
NewGatewayService
(
accountRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
}
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
func
newMinimalSoraGatewayService
(
soraClient
service
.
SoraClient
)
*
service
.
SoraGatewayService
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
return
service
.
NewSoraGatewayService
(
soraClient
,
nil
,
nil
,
cfg
)
}
// ==================== processGeneration: 更多路径测试 ====================
func
TestProcessGeneration_SelectAccountError
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"选择账号失败"
)
}
func
TestProcessGeneration_SoraGatewayServiceNil
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 提供可用账号使 SelectAccountForModel 成功
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// soraGatewayService 为 nil
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"soraGatewayService"
)
}
func
TestProcessGeneration_ForwardError
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// SoraClient 返回视频任务失败
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"content policy violation"
,
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"生成失败"
)
}
func
TestProcessGeneration_ForwardErrorCancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
repo
.
getByIDOverrideAfterN
=
1
repo
.
getByIDOverrideStatus
=
"cancelled"
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"reject"
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestProcessGeneration_ForwardSuccessNoMediaURL
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
// SoraClient 返回 completed 但无 URL
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
nil
,
// 无 URL
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"未获取到媒体 URL"
)
}
func
TestProcessGeneration_ForwardSuccessCancelledBeforeStore
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
// 第 2 次返回 "cancelled" 状态,模拟外部取消
repo
.
getByIDOverrideAfterN
=
1
repo
.
getByIDOverrideStatus
=
"cancelled"
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
require
.
Equal
(
t
,
"generating"
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestProcessGeneration_FullSuccessUpstream
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
// 无 S3 和本地存储,降级到 upstream
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeUpstream
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
NotEmpty
(
t
,
repo
.
gens
[
1
]
.
MediaURL
)
}
func
TestProcessGeneration_FullSuccessWithS3
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
sourceServer
.
URL
+
"/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
service
.
SoraStorageTypeS3
,
repo
.
gens
[
1
]
.
StorageType
)
require
.
NotEmpty
(
t
,
repo
.
gens
[
1
]
.
S3ObjectKeys
)
require
.
Greater
(
t
,
repo
.
gens
[
1
]
.
FileSizeBytes
,
int64
(
0
))
// 验证配额已累加
require
.
Greater
(
t
,
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
,
int64
(
0
))
}
func
TestProcessGeneration_MarkCompletedFails
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复"
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
repo
.
updateCallCount
=
new
(
int32
)
repo
.
updateFailAfterN
=
1
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
soraClient
:=
&
stubSoraClientForHandler
{
videoStatus
:
&
service
.
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/video.mp4"
},
},
}
soraGatewayService
:=
newMinimalSoraGatewayService
(
soraClient
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test prompt"
,
"video"
,
""
,
1
)
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
require
.
Equal
(
t
,
"completed"
,
repo
.
gens
[
1
]
.
Status
)
}
// ==================== cleanupStoredMedia 直接测试 ====================
func
TestCleanupStoredMedia_S3Path
(
t
*
testing
.
T
)
{
// S3 清理路径:s3Storage 为 nil 时不 panic
h
:=
&
SoraClientHandler
{}
// 不应 panic
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
},
nil
)
}
func
TestCleanupStoredMedia_LocalPath
(
t
*
testing
.
T
)
{
// 本地清理路径:mediaStorage 为 nil 时不 panic
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
[]
string
{
"/tmp/test.mp4"
})
}
func
TestCleanupStoredMedia_UpstreamPath
(
t
*
testing
.
T
)
{
// upstream 类型不清理
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeUpstream
,
nil
,
nil
)
}
func
TestCleanupStoredMedia_EmptyKeys
(
t
*
testing
.
T
)
{
// 空 keys 不触发清理
h
:=
&
SoraClientHandler
{}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
nil
,
nil
)
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
nil
)
}
// ==================== DeleteGeneration: 本地存储清理路径 ====================
func
TestDeleteGeneration_LocalStorageCleanup
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-delete-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"video/test.mp4"
,
MediaURLs
:
[]
string
{
"video/test.mp4"
},
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback
(
t
*
testing
.
T
)
{
// MediaURLs 为空,使用 MediaURL 作为清理路径
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-delete-fallback-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"video/test.mp4"
,
MediaURLs
:
nil
,
// 空
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestDeleteGeneration_NonLocalStorage_SkipCleanup
(
t
*
testing
.
T
)
{
// 非本地存储类型 → 跳过清理
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeUpstream
,
MediaURL
:
"https://upstream.com/v.mp4"
,
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestDeleteGeneration_DeleteError
(
t
*
testing
.
T
)
{
// repo.Delete 出错
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
}
repo
.
deleteErr
=
fmt
.
Errorf
(
"delete failed"
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
// ==================== fetchUpstreamModels 测试 ====================
func
TestFetchUpstreamModels_NilGateway
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
h
:=
&
SoraClientHandler
{}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"gatewayService 未初始化"
)
}
func
TestFetchUpstreamModels_NoAccounts
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"选择 Sora 账号失败"
)
}
func
TestFetchUpstreamModels_NonAPIKeyAccount
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
"oauth"
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"不支持模型同步"
)
}
func
TestFetchUpstreamModels_MissingAPIKey
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://sora.test"
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"api_key"
)
}
func
TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
}
func
TestFetchUpstreamModels_UpstreamReturns500
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"状态码 500"
)
}
func
TestFetchUpstreamModels_UpstreamReturnsInvalidJSON
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"not json"
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"解析响应失败"
)
}
func
TestFetchUpstreamModels_UpstreamReturnsEmptyList
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"空模型列表"
)
}
func
TestFetchUpstreamModels_Success
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// 验证请求头
require
.
Equal
(
t
,
"Bearer sk-test"
,
r
.
Header
.
Get
(
"Authorization"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
r
.
URL
.
Path
,
"/sora/v1/models"
))
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
families
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
families
)
}
func
TestFetchUpstreamModels_UnrecognizedModels
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
_
,
err
:=
h
.
fetchUpstreamModels
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"未能从上游模型列表中识别"
)
}
// ==================== getModelFamilies 缓存测试 ====================
func
TestGetModelFamilies_CachesLocalConfig
(
t
*
testing
.
T
)
{
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
h
:=
&
SoraClientHandler
{}
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
families2
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
Equal
(
t
,
families
,
families2
)
require
.
False
(
t
,
h
.
modelCacheUpstream
)
}
func
TestGetModelFamilies_CachesUpstreamResult
(
t
*
testing
.
T
)
{
t
.
Skip
(
"TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复"
)
ts
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`
))
}))
defer
ts
.
Close
()
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
[]
service
.
Account
{
{
ID
:
1
,
Type
:
service
.
AccountTypeAPIKey
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
ts
.
URL
}},
},
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
gatewayService
:
gatewayService
}
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
require
.
True
(
t
,
h
.
modelCacheUpstream
)
// 第二次调用命中缓存
families2
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
Equal
(
t
,
families
,
families2
)
}
func
TestGetModelFamilies_ExpiredCacheRefreshes
(
t
*
testing
.
T
)
{
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
h
:=
&
SoraClientHandler
{
cachedFamilies
:
[]
service
.
SoraModelFamily
{{
ID
:
"old"
}},
modelCacheTime
:
time
.
Now
()
.
Add
(
-
10
*
time
.
Minute
),
// 已过期
modelCacheUpstream
:
false
,
}
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
families
:=
h
.
getModelFamilies
(
context
.
Background
())
require
.
NotEmpty
(
t
,
families
)
// 缓存已刷新,不再是 "old"
found
:=
false
for
_
,
f
:=
range
families
{
if
f
.
ID
==
"old"
{
found
=
true
}
}
require
.
False
(
t
,
found
,
"过期缓存应被刷新"
)
}
// ==================== processGeneration: groupID 与 ForcePlatform ====================
func
TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails
(
t
*
testing
.
T
)
{
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"pending"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 空账号列表 → SelectAccountForModel 失败
accountRepo
:=
&
stubAccountRepoForHandler
{
accounts
:
nil
}
gatewayService
:=
newMinimalGatewayService
(
accountRepo
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
gatewayService
:
gatewayService
,
}
h
.
processGeneration
(
1
,
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
,
""
,
1
)
require
.
Equal
(
t
,
"failed"
,
repo
.
gens
[
1
]
.
Status
)
require
.
Contains
(
t
,
repo
.
gens
[
1
]
.
ErrorMessage
,
"选择账号失败"
)
}
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
func
TestGenerate_CheckQuotaNonQuotaError
(
t
*
testing
.
T
)
{
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
repo
:=
newStubSoraGenRepo
()
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
NewSoraClientHandler
(
genService
,
quotaService
,
nil
,
nil
,
nil
,
nil
,
nil
)
body
:=
`{"model":"sora2-landscape-10s","prompt":"test"}`
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
body
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
rec
.
Code
)
}
// ==================== Generate: CreatePending 并发限制错误 ====================
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
type
stubSoraGenRepoWithAtomicCreate
struct
{
stubSoraGenRepo
limitErr
error
}
func
(
r
*
stubSoraGenRepoWithAtomicCreate
)
CreatePendingWithLimit
(
_
context
.
Context
,
gen
*
service
.
SoraGeneration
,
_
[]
string
,
_
int64
)
error
{
if
r
.
limitErr
!=
nil
{
return
r
.
limitErr
}
return
r
.
stubSoraGenRepo
.
Create
(
context
.
Background
(),
gen
)
}
func
TestGenerate_CreatePendingConcurrencyLimit
(
t
*
testing
.
T
)
{
repo
:=
&
stubSoraGenRepoWithAtomicCreate
{
stubSoraGenRepo
:
*
newStubSoraGenRepo
(),
limitErr
:
service
.
ErrSoraGenerationConcurrencyLimit
,
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
NewSoraClientHandler
(
genService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
body
:=
`{"model":"sora2-landscape-10s","prompt":"test"}`
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generate"
,
body
,
1
)
h
.
Generate
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"3"
)
}
// ==================== SaveToStorage: 配额超限 ====================
func
TestSaveToStorage_QuotaExceeded
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户配额已满
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
,
SoraStorageUsedBytes
:
10
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
}
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
func
TestSaveToStorage_QuotaNonQuotaError
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo
:=
newStubUserRepoForHandler
()
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== SaveToStorage: MediaURLs 全为空 ====================
func
TestSaveToStorage_EmptyMediaURLs
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
""
,
MediaURLs
:
[]
string
{},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
resp
:=
parseResponse
(
t
,
rec
)
require
.
Contains
(
t
,
resp
[
"message"
],
"已过期"
)
}
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
func
TestSaveToStorage_MultiURL_SecondUploadFails
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"fail-second"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v1.mp4"
,
MediaURLs
:
[]
string
{
sourceServer
.
URL
+
"/v1.mp4"
,
sourceServer
.
URL
+
"/v2.mp4"
},
}
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
func
TestSaveToStorage_MarkCompletedFailsWithQuotaRollback
(
t
*
testing
.
T
)
{
sourceServer
:=
newFakeSourceServer
()
defer
sourceServer
.
Close
()
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
"upstream"
,
MediaURL
:
sourceServer
.
URL
+
"/v.mp4"
,
}
repo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
userRepo
:=
newStubUserRepoForHandler
()
userRepo
.
users
[
1
]
=
&
service
.
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
100
*
1024
*
1024
,
SoraStorageUsedBytes
:
0
,
}
quotaService
:=
service
.
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/save"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
SaveToStorage
(
c
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
func
TestCleanupStoredMedia_WithS3Storage_ActualDelete
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"ok"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
,
"key2"
},
nil
)
}
func
TestCleanupStoredMedia_S3DeleteFails_LogOnly
(
t
*
testing
.
T
)
{
fakeS3
:=
newFakeS3Server
(
"fail"
)
defer
fakeS3
.
Close
()
s3Storage
:=
newS3StorageForHandler
(
fakeS3
.
URL
)
h
:=
&
SoraClientHandler
{
s3Storage
:
s3Storage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeS3
,
[]
string
{
"key1"
},
nil
)
}
func
TestCleanupStoredMedia_LocalDeleteFails_LogOnly
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-cleanup-fail-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
h
:=
&
SoraClientHandler
{
mediaStorage
:
mediaStorage
}
h
.
cleanupStoredMedia
(
context
.
Background
(),
service
.
SoraStorageTypeLocal
,
nil
,
[]
string
{
"nonexistent/file.mp4"
})
}
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
func
TestDeleteGeneration_LocalStorageDeleteFails_LogOnly
(
t
*
testing
.
T
)
{
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"sora-del-test-*"
)
require
.
NoError
(
t
,
err
)
defer
os
.
RemoveAll
(
tmpDir
)
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
},
},
}
mediaStorage
:=
service
.
NewSoraMediaStorage
(
cfg
)
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
,
StorageType
:
service
.
SoraStorageTypeLocal
,
MediaURL
:
"nonexistent/video.mp4"
,
MediaURLs
:
[]
string
{
"nonexistent/video.mp4"
},
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
,
mediaStorage
:
mediaStorage
}
c
,
rec
:=
makeGinContext
(
"DELETE"
,
"/api/v1/sora/generations/1"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
DeleteGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
// ==================== CancelGeneration: 任务已结束冲突 ====================
func
TestCancelGeneration_AlreadyCompleted
(
t
*
testing
.
T
)
{
repo
:=
newStubSoraGenRepo
()
repo
.
gens
[
1
]
=
&
service
.
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
"completed"
}
genService
:=
service
.
NewSoraGenerationService
(
repo
,
nil
,
nil
)
h
:=
&
SoraClientHandler
{
genService
:
genService
}
c
,
rec
:=
makeGinContext
(
"POST"
,
"/api/v1/sora/generations/1/cancel"
,
""
,
1
)
c
.
Params
=
gin
.
Params
{{
Key
:
"id"
,
Value
:
"1"
}}
h
.
CancelGeneration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
}
backend/internal/handler/sora_gateway_handler.go
deleted
100644 → 0
View file @
dbb248df
package
handler
import
(
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type
SoraGatewayHandler
struct
{
gatewayService
*
service
.
GatewayService
soraGatewayService
*
service
.
SoraGatewayService
billingCacheService
*
service
.
BillingCacheService
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
streamMode
string
soraTLSEnabled
bool
soraMediaSigningKey
string
soraMediaRoot
string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func
NewSoraGatewayHandler
(
gatewayService
*
service
.
GatewayService
,
soraGatewayService
*
service
.
SoraGatewayService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
,
cfg
*
config
.
Config
,
)
*
SoraGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
3
streamMode
:=
"force"
soraTLSEnabled
:=
true
signKey
:=
""
mediaRoot
:=
"/app/data/sora"
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
if
cfg
.
Gateway
.
MaxAccountSwitches
>
0
{
maxAccountSwitches
=
cfg
.
Gateway
.
MaxAccountSwitches
}
if
mode
:=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraStreamMode
);
mode
!=
""
{
streamMode
=
mode
}
soraTLSEnabled
=
!
cfg
.
Sora
.
Client
.
DisableTLSFingerprint
signKey
=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraMediaSigningKey
)
if
root
:=
strings
.
TrimSpace
(
cfg
.
Sora
.
Storage
.
LocalPath
);
root
!=
""
{
mediaRoot
=
root
}
}
return
&
SoraGatewayHandler
{
gatewayService
:
gatewayService
,
soraGatewayService
:
soraGatewayService
,
billingCacheService
:
billingCacheService
,
usageRecordWorkerPool
:
usageRecordWorkerPool
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
streamMode
:
strings
.
ToLower
(
streamMode
),
soraTLSEnabled
:
soraTLSEnabled
,
soraMediaSigningKey
:
signKey
,
soraMediaRoot
:
mediaRoot
,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func
(
h
*
SoraGatewayHandler
)
ChatCompletions
(
c
*
gin
.
Context
)
{
apiKey
,
ok
:=
middleware2
.
GetAPIKeyFromContext
(
c
)
if
!
ok
{
h
.
errorResponse
(
c
,
http
.
StatusUnauthorized
,
"authentication_error"
,
"Invalid API key"
)
return
}
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"User context not found"
)
return
}
reqLog
:=
requestLogger
(
c
,
"handler.sora_gateway.chat_completions"
,
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
)
body
,
err
:=
pkghttputil
.
ReadRequestBodyWithPrealloc
(
c
.
Request
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
return
}
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to read request body"
)
return
}
if
len
(
body
)
==
0
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Request body is empty"
)
return
}
setOpsRequestContext
(
c
,
""
,
false
,
body
)
// 校验请求体 JSON 合法性
if
!
gjson
.
ValidBytes
(
body
)
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult
:=
gjson
.
GetBytes
(
body
,
"model"
)
if
!
modelResult
.
Exists
()
||
modelResult
.
Type
!=
gjson
.
String
||
modelResult
.
String
()
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
reqModel
:=
modelResult
.
String
()
msgsResult
:=
gjson
.
GetBytes
(
body
,
"messages"
)
if
!
msgsResult
.
IsArray
()
||
len
(
msgsResult
.
Array
())
==
0
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"messages is required"
)
return
}
clientStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
clientStream
))
if
!
clientStream
{
if
h
.
streamMode
==
"error"
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Sora requires stream=true"
)
return
}
var
err
error
body
,
err
=
sjson
.
SetBytes
(
body
,
"stream"
,
true
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to process request"
)
return
}
}
setOpsRequestContext
(
c
,
reqModel
,
clientStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
clientStream
,
false
)))
platform
:=
""
if
forced
,
ok
:=
middleware2
.
GetForcePlatformFromContext
(
c
);
ok
{
platform
=
forced
}
else
if
apiKey
.
Group
!=
nil
{
platform
=
apiKey
.
Group
.
Platform
}
if
platform
!=
service
.
PlatformSora
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"This endpoint only supports Sora platform"
)
return
}
streamStarted
:=
false
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
maxWait
:=
service
.
CalculateMaxWait
(
subject
.
Concurrency
)
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
,
maxWait
)
waitCounted
:=
false
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.user_wait_counter_increment_failed"
,
zap
.
Error
(
err
))
}
else
if
!
canWait
{
reqLog
.
Info
(
"sora.user_wait_queue_full"
,
zap
.
Int
(
"max_wait"
,
maxWait
))
h
.
errorResponse
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
)
return
}
if
err
==
nil
&&
canWait
{
waitCounted
=
true
}
defer
func
()
{
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
}
}()
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
clientStream
,
&
streamStarted
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.user_slot_acquire_failed"
,
zap
.
Error
(
err
))
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
waitCounted
=
false
}
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
defer
userReleaseFunc
()
}
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
reqLog
.
Info
(
"sora.billing_eligibility_check_failed"
,
zap
.
Error
(
err
))
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
sessionHash
:=
generateOpenAISessionHash
(
c
,
body
)
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
var
lastFailoverBody
[]
byte
var
lastFailoverHeaders
http
.
Header
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
,
""
,
int64
(
0
))
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_select_failed"
,
zap
.
Error
(
err
),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)),
)
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int
(
"last_upstream_status"
,
lastFailoverStatus
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.failover_exhausted_no_available_accounts"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
proxyBound
:=
account
.
ProxyID
!=
nil
proxyID
:=
int64
(
0
)
if
account
.
ProxyID
!=
nil
{
proxyID
=
*
account
.
ProxyID
}
tlsFingerprintEnabled
:=
h
.
soraTLSEnabled
accountReleaseFunc
:=
selection
.
ReleaseFunc
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_wait_counter_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
}
else
if
!
canWait
{
reqLog
.
Info
(
"sora.account_wait_queue_full"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"max_waiting"
,
selection
.
WaitPlan
.
MaxWaiting
),
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}()
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
clientStream
,
&
streamStarted
,
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
result
,
err
:=
h
.
soraGatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
clientStream
)
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverStatus
=
failoverErr
.
StatusCode
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_exhausted"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
lastFailoverStatus
=
failoverErr
.
StatusCode
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
switchCount
++
upstreamErrCode
,
upstreamErrMsg
:=
extractUpstreamErrorCodeAndMessage
(
lastFailoverBody
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
String
(
"upstream_error_code"
,
upstreamErrCode
),
zap
.
String
(
"upstream_error_message"
,
upstreamErrMsg
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_switching"
,
fields
...
)
continue
}
reqLog
.
Error
(
"sora.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
return
}
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
inboundEndpoint
:=
GetInboundEndpoint
(
c
)
upstreamEndpoint
:=
GetUpstreamEndpoint
(
c
,
account
.
Platform
)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
Account
:
account
,
Subscription
:
subscription
,
InboundEndpoint
:
inboundEndpoint
,
UpstreamEndpoint
:
upstreamEndpoint
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.sora_gateway.chat_completions"
),
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
zap
.
String
(
"model"
,
reqModel
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
)
.
Error
(
"sora.record_usage_failed"
,
zap
.
Error
(
err
))
}
})
reqLog
.
Debug
(
"sora.request_completed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"switch_count"
,
switchCount
),
)
return
}
}
func
generateOpenAISessionHash
(
c
*
gin
.
Context
,
body
[]
byte
)
string
{
if
c
==
nil
{
return
""
}
sessionID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"session_id"
))
if
sessionID
==
""
{
sessionID
=
strings
.
TrimSpace
(
c
.
GetHeader
(
"conversation_id"
))
}
if
sessionID
==
""
&&
len
(
body
)
>
0
{
sessionID
=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"prompt_cache_key"
)
.
String
())
}
if
sessionID
==
""
{
return
""
}
hash
:=
sha256
.
Sum256
([]
byte
(
sessionID
))
return
hex
.
EncodeToString
(
hash
[
:
])
}
func
(
h
*
SoraGatewayHandler
)
submitUsageRecordTask
(
task
service
.
UsageRecordTask
)
{
if
task
==
nil
{
return
}
if
h
.
usageRecordWorkerPool
!=
nil
{
h
.
usageRecordWorkerPool
.
Submit
(
task
)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
defer
func
()
{
if
recovered
:=
recover
();
recovered
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.sora_gateway.chat_completions"
),
zap
.
Any
(
"panic"
,
recovered
),
)
.
Error
(
"sora.usage_record_task_panic_recovered"
)
}
}()
task
(
ctx
)
}
func
(
h
*
SoraGatewayHandler
)
handleConcurrencyError
(
c
*
gin
.
Context
,
err
error
,
slotType
string
,
streamStarted
bool
)
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
,
streamStarted
bool
)
{
upstreamMsg
:=
service
.
ExtractUpstreamErrorMessage
(
responseBody
)
service
.
SetOpsUpstreamError
(
c
,
statusCode
,
upstreamMsg
,
""
)
status
,
errType
,
errMsg
:=
h
.
mapUpstreamError
(
statusCode
,
responseHeaders
,
responseBody
)
h
.
handleStreamingAwareError
(
c
,
status
,
errType
,
errMsg
,
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
mapUpstreamError
(
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
)
(
int
,
string
,
string
)
{
if
isSoraCloudflareChallengeResponse
(
statusCode
,
responseHeaders
,
responseBody
)
{
baseMsg
:=
fmt
.
Sprintf
(
"Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry."
,
statusCode
)
return
http
.
StatusBadGateway
,
"upstream_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
upstreamCode
,
upstreamMessage
:=
extractUpstreamErrorCodeAndMessage
(
responseBody
)
if
strings
.
EqualFold
(
upstreamCode
,
"cf_shield_429"
)
{
baseMsg
:=
"Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
if
shouldPassthroughSoraUpstreamMessage
(
statusCode
,
upstreamMessage
)
{
switch
statusCode
{
case
401
,
403
,
404
,
500
,
502
,
503
,
504
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
upstreamMessage
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
upstreamMessage
}
}
switch
statusCode
{
case
401
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream authentication failed, please contact administrator"
case
403
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream access forbidden, please contact administrator"
case
404
:
if
strings
.
EqualFold
(
upstreamCode
,
"unsupported_country_code"
)
{
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream region capability unavailable for this account, please contact administrator"
}
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream capability unavailable for this account, please contact administrator"
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Upstream rate limit exceeded, please retry later"
case
529
:
return
http
.
StatusServiceUnavailable
,
"upstream_error"
,
"Upstream service overloaded, please retry later"
case
500
,
502
,
503
,
504
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream service temporarily unavailable"
default
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
}
}
func
cloneHTTPHeaders
(
headers
http
.
Header
)
http
.
Header
{
if
headers
==
nil
{
return
nil
}
return
headers
.
Clone
()
}
func
extractSoraFailoverHeaderInsights
(
headers
http
.
Header
,
body
[]
byte
)
(
rayID
,
mitigated
,
contentType
string
)
{
if
headers
!=
nil
{
mitigated
=
strings
.
TrimSpace
(
headers
.
Get
(
"cf-mitigated"
))
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"content-type"
))
if
contentType
==
""
{
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"Content-Type"
))
}
}
rayID
=
soraerror
.
ExtractCloudflareRayID
(
headers
,
body
)
return
rayID
,
mitigated
,
contentType
}
func
isSoraCloudflareChallengeResponse
(
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
bool
{
return
soraerror
.
IsCloudflareChallengeResponse
(
statusCode
,
headers
,
body
)
}
func
shouldPassthroughSoraUpstreamMessage
(
statusCode
int
,
message
string
)
bool
{
message
=
strings
.
TrimSpace
(
message
)
if
message
==
""
{
return
false
}
if
statusCode
==
http
.
StatusForbidden
||
statusCode
==
http
.
StatusTooManyRequests
{
lower
:=
strings
.
ToLower
(
message
)
if
strings
.
Contains
(
lower
,
"<html"
)
||
strings
.
Contains
(
lower
,
"<!doctype html"
)
||
strings
.
Contains
(
lower
,
"window._cf_chl_opt"
)
{
return
false
}
}
return
true
}
func
formatSoraCloudflareChallengeMessage
(
base
string
,
headers
http
.
Header
,
body
[]
byte
)
string
{
return
soraerror
.
FormatCloudflareChallengeMessage
(
base
,
headers
,
body
)
}
func
extractUpstreamErrorCodeAndMessage
(
body
[]
byte
)
(
string
,
string
)
{
return
soraerror
.
ExtractUpstreamErrorCodeAndMessage
(
body
)
}
func
(
h
*
SoraGatewayHandler
)
handleStreamingAwareError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
streamStarted
bool
)
{
if
streamStarted
{
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
ok
{
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
);
err
!=
nil
{
_
=
c
.
Error
(
err
)
}
flusher
.
Flush
()
}
return
}
h
.
errorResponse
(
c
,
status
,
errType
,
message
)
}
func
(
h
*
SoraGatewayHandler
)
errorResponse
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
,
},
})
}
// MediaProxy serves local Sora media files.
func
(
h
*
SoraGatewayHandler
)
MediaProxy
(
c
*
gin
.
Context
)
{
h
.
proxySoraMedia
(
c
,
false
)
}
// MediaProxySigned serves local Sora media files with signature verification.
func
(
h
*
SoraGatewayHandler
)
MediaProxySigned
(
c
*
gin
.
Context
)
{
h
.
proxySoraMedia
(
c
,
true
)
}
func
(
h
*
SoraGatewayHandler
)
proxySoraMedia
(
c
*
gin
.
Context
,
requireSignature
bool
)
{
rawPath
:=
c
.
Param
(
"filepath"
)
if
rawPath
==
""
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
cleaned
:=
path
.
Clean
(
rawPath
)
if
!
strings
.
HasPrefix
(
cleaned
,
"/image/"
)
&&
!
strings
.
HasPrefix
(
cleaned
,
"/video/"
)
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
query
:=
c
.
Request
.
URL
.
Query
()
if
requireSignature
{
if
h
.
soraMediaSigningKey
==
""
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 媒体签名未配置"
,
},
})
return
}
expiresStr
:=
strings
.
TrimSpace
(
query
.
Get
(
"expires"
))
signature
:=
strings
.
TrimSpace
(
query
.
Get
(
"sig"
))
expires
,
err
:=
strconv
.
ParseInt
(
expiresStr
,
10
,
64
)
if
err
!=
nil
||
expires
<=
time
.
Now
()
.
Unix
()
{
c
.
JSON
(
http
.
StatusUnauthorized
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"authentication_error"
,
"message"
:
"Sora 媒体签名已过期"
,
},
})
return
}
query
.
Del
(
"sig"
)
query
.
Del
(
"expires"
)
signingQuery
:=
query
.
Encode
()
if
!
service
.
VerifySoraMediaURL
(
cleaned
,
signingQuery
,
expires
,
signature
,
h
.
soraMediaSigningKey
)
{
c
.
JSON
(
http
.
StatusUnauthorized
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"authentication_error"
,
"message"
:
"Sora 媒体签名无效"
,
},
})
return
}
}
if
strings
.
TrimSpace
(
h
.
soraMediaRoot
)
==
""
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 媒体目录未配置"
,
},
})
return
}
relative
:=
strings
.
TrimPrefix
(
cleaned
,
"/"
)
localPath
:=
filepath
.
Join
(
h
.
soraMediaRoot
,
filepath
.
FromSlash
(
relative
))
if
_
,
err
:=
os
.
Stat
(
localPath
);
err
!=
nil
{
if
os
.
IsNotExist
(
err
)
{
c
.
Status
(
http
.
StatusNotFound
)
return
}
c
.
Status
(
http
.
StatusInternalServerError
)
return
}
c
.
File
(
localPath
)
}
backend/internal/handler/sora_gateway_handler_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
handler
import
(
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var
_
service
.
SoraClient
=
(
*
stubSoraClient
)(
nil
)
var
_
service
.
AccountRepository
=
(
*
stubAccountRepo
)(
nil
)
var
_
service
.
GroupRepository
=
(
*
stubGroupRepo
)(
nil
)
var
_
service
.
UsageLogRepository
=
(
*
stubUsageLogRepo
)(
nil
)
type
stubSoraClient
struct
{
imageURLs
[]
string
}
func
(
s
*
stubSoraClient
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClient
)
UploadImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
return
"upload"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraVideoRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraStoryboardRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClient
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
(
*
service
.
SoraCameoStatus
,
error
)
{
return
&
service
.
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClient
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClient
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClient
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
service
.
Account
,
generationID
string
)
(
string
,
error
)
{
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClient
)
DeletePost
(
ctx
context
.
Context
,
account
*
service
.
Account
,
postID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
service
.
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClient
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
service
.
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
return
"enhanced prompt"
,
nil
}
func
(
s
*
stubSoraClient
)
GetImageTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
taskID
string
)
(
*
service
.
SoraImageTaskStatus
,
error
)
{
return
&
service
.
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
s
.
imageURLs
},
nil
}
func
(
s
*
stubSoraClient
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
taskID
string
)
(
*
service
.
SoraVideoTaskStatus
,
error
)
{
return
&
service
.
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
s
.
imageURLs
},
nil
}
type
stubAccountRepo
struct
{
accounts
map
[
int64
]
*
service
.
Account
}
func
(
r
*
stubAccountRepo
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
if
acc
,
ok
:=
r
.
accounts
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
service
.
ErrAccountNotFound
}
func
(
r
*
stubAccountRepo
)
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
service
.
Account
,
error
)
{
var
result
[]
*
service
.
Account
for
_
,
id
:=
range
ids
{
if
acc
,
ok
:=
r
.
accounts
[
id
];
ok
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
r
*
stubAccountRepo
)
ExistsByID
(
ctx
context
.
Context
,
id
int64
)
(
bool
,
error
)
{
_
,
ok
:=
r
.
accounts
[
id
]
return
ok
,
nil
}
func
(
r
*
stubAccountRepo
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
map
[
string
]
int64
{},
nil
}
func
(
r
*
stubAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
AutoPauseExpiredAccounts
(
ctx
context
.
Context
,
now
time
.
Time
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepo
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulable
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulable
(),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulable
(),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
listSchedulableByPlatform
(
platform
),
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
for
_
,
platform
:=
range
platforms
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
break
}
}
}
return
result
,
nil
}
func
(
r
*
stubAccountRepo
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
r
*
stubAccountRepo
)
ListSchedulableUngroupedByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
func
(
r
*
stubAccountRepo
)
ListSchedulableUngroupedByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
return
r
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
r
*
stubAccountRepo
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubAccountRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
ResetQuotaUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubAccountRepo
)
listSchedulable
()
[]
service
.
Account
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
}
}
return
result
}
func
(
r
*
stubAccountRepo
)
listSchedulableByPlatform
(
platform
string
)
[]
service
.
Account
{
var
result
[]
service
.
Account
for
_
,
acc
:=
range
r
.
accounts
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
*
acc
)
}
}
return
result
}
type
stubGroupRepo
struct
{
group
*
service
.
Group
}
func
(
r
*
stubGroupRepo
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
r
.
group
,
nil
}
func
(
r
*
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
r
.
group
,
nil
}
func
(
r
*
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubGroupRepo
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
int64
,
error
)
{
return
0
,
0
,
nil
}
func
(
r
*
stubGroupRepo
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubGroupRepo
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepo
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepo
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
service
.
GroupSortOrderUpdate
)
error
{
return
nil
}
type
stubUsageLogRepo
struct
{}
func
(
s
*
stubUsageLogRepo
)
Create
(
ctx
context
.
Context
,
log
*
service
.
UsageLog
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
UsageLog
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAPIKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAPIKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserBreakdownStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
dim
usagestats
.
UserBreakdownDimension
,
limit
int
)
([]
usagestats
.
UserBreakdownItem
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAllGroupUsageSummary
(
ctx
context
.
Context
,
todayStart
time
.
Time
)
([]
usagestats
.
GroupUsageSummary
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserSpendingRanking
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
limit
int
)
(
*
usagestats
.
UserSpendingRankingResponse
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
usagestats
.
UserDashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyDashboardStats
(
ctx
context
.
Context
,
apiKeyID
int64
)
(
*
usagestats
.
UserDashboardStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserModelStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
usagestats
.
UsageLogFilters
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
AccountUsageStatsResponse
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUserStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAPIKeyStatsAggregated
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetAccountStatsAggregated
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetModelStatsAggregated
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubUsageLogRepo
)
GetDailyStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
map
[
string
]
any
,
error
)
{
return
nil
,
nil
}
func
TestSoraGatewayHandler_ChatCompletions
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
,
Gateway
:
config
.
GatewayConfig
{
SoraStreamMode
:
"force"
,
MaxAccountSwitches
:
1
,
Scheduling
:
config
.
GatewaySchedulingConfig
{
LoadBatchEnabled
:
false
,
},
},
Concurrency
:
config
.
ConcurrencyConfig
{
PingInterval
:
0
},
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
BaseURL
:
"https://sora.test"
,
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
account
:=
&
service
.
Account
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Priority
:
1
}
accountRepo
:=
&
stubAccountRepo
{
accounts
:
map
[
int64
]
*
service
.
Account
{
account
.
ID
:
account
}}
group
:=
&
service
.
Group
{
ID
:
1
,
Platform
:
service
.
PlatformSora
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
}
groupRepo
:=
&
stubGroupRepo
{
group
:
group
}
usageLogRepo
:=
&
stubUsageLogRepo
{}
deferredService
:=
service
.
NewDeferredService
(
accountRepo
,
nil
,
0
)
billingService
:=
service
.
NewBillingService
(
cfg
,
nil
)
concurrencyService
:=
service
.
NewConcurrencyService
(
testutil
.
StubConcurrencyCache
{})
billingCacheService
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cfg
)
t
.
Cleanup
(
func
()
{
billingCacheService
.
Stop
()
})
gatewayService
:=
service
.
NewGatewayService
(
accountRepo
,
groupRepo
,
usageLogRepo
,
nil
,
nil
,
nil
,
nil
,
testutil
.
StubGatewayCache
{},
cfg
,
nil
,
concurrencyService
,
billingService
,
nil
,
billingCacheService
,
nil
,
nil
,
deferredService
,
nil
,
testutil
.
StubSessionLimitCache
{},
nil
,
// rpmCache
nil
,
// digestStore
nil
,
// settingService
nil
,
// tlsFPProfileService
nil
,
// channelService
nil
,
// resolver
)
soraClient
:=
&
stubSoraClient
{
imageURLs
:
[]
string
{
"https://example.com/a.png"
}}
soraGatewayService
:=
service
.
NewSoraGatewayService
(
soraClient
,
nil
,
nil
,
cfg
)
handler
:=
NewSoraGatewayHandler
(
gatewayService
,
soraGatewayService
,
concurrencyService
,
billingCacheService
,
nil
,
cfg
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
`{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/sora/v1/chat/completions"
,
strings
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
apiKey
:=
&
service
.
APIKey
{
ID
:
1
,
UserID
:
1
,
Status
:
service
.
StatusActive
,
GroupID
:
&
group
.
ID
,
User
:
&
service
.
User
{
ID
:
1
,
Concurrency
:
1
,
Status
:
service
.
StatusActive
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
apiKey
.
User
.
Concurrency
})
handler
.
ChatCompletions
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
NotEmpty
(
t
,
resp
[
"media_url"
])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func
TestSoraHandler_StreamForcing
(
t
*
testing
.
T
)
{
// 测试 1:stream=false 时 sjson 强制修改为 true
body
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`
)
clientStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
require
.
False
(
t
,
clientStream
)
newBody
,
err
:=
sjson
.
SetBytes
(
body
,
"stream"
,
true
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
gjson
.
GetBytes
(
newBody
,
"stream"
)
.
Bool
())
// 测试 2:stream=true 时不修改
body2
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`
)
require
.
True
(
t
,
gjson
.
GetBytes
(
body2
,
"stream"
)
.
Bool
())
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
body3
:=
[]
byte
(
`{"model":"sora","messages":[{"role":"user","content":"test"}]}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body3
,
"stream"
)
.
Bool
())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func
TestSoraHandler_ValidationExtraction
(
t
*
testing
.
T
)
{
// model 缺失
body
:=
[]
byte
(
`{"messages":[{"role":"user","content":"test"}]}`
)
modelResult
:=
gjson
.
GetBytes
(
body
,
"model"
)
require
.
True
(
t
,
!
modelResult
.
Exists
()
||
modelResult
.
Type
!=
gjson
.
String
||
modelResult
.
String
()
==
""
)
// model 为数字 → 类型不是 gjson.String,应被拒绝
body1b
:=
[]
byte
(
`{"model":123,"messages":[{"role":"user","content":"test"}]}`
)
modelResult1b
:=
gjson
.
GetBytes
(
body1b
,
"model"
)
require
.
True
(
t
,
modelResult1b
.
Exists
())
require
.
NotEqual
(
t
,
gjson
.
String
,
modelResult1b
.
Type
)
// messages 缺失
body2
:=
[]
byte
(
`{"model":"sora"}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body2
,
"messages"
)
.
IsArray
())
// messages 不是 JSON 数组(字符串)
body3
:=
[]
byte
(
`{"model":"sora","messages":"not array"}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body3
,
"messages"
)
.
IsArray
())
// messages 是对象而非数组 → IsArray 返回 false
body4
:=
[]
byte
(
`{"model":"sora","messages":{}}`
)
require
.
False
(
t
,
gjson
.
GetBytes
(
body4
,
"messages"
)
.
IsArray
())
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
body5
:=
[]
byte
(
`{"model":"sora","messages":[]}`
)
msgsResult
:=
gjson
.
GetBytes
(
body5
,
"messages"
)
require
.
True
(
t
,
msgsResult
.
IsArray
())
require
.
Equal
(
t
,
0
,
len
(
msgsResult
.
Array
()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require
.
False
(
t
,
gjson
.
ValidBytes
([]
byte
(
`{invalid`
)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func
TestGenerateOpenAISessionHash_WithBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
// 从 body 提取 prompt_cache_key
body
:=
[]
byte
(
`{"model":"sora","prompt_cache_key":"session-abc"}`
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
"POST"
,
"/"
,
nil
)
hash
:=
generateOpenAISessionHash
(
c
,
body
)
require
.
NotEmpty
(
t
,
hash
)
// 无 prompt_cache_key 且无 header → 空 hash
body2
:=
[]
byte
(
`{"model":"sora"}`
)
hash2
:=
generateOpenAISessionHash
(
c
,
body2
)
require
.
Empty
(
t
,
hash2
)
// header 优先于 body
c
.
Request
.
Header
.
Set
(
"session_id"
,
"from-header"
)
hash3
:=
generateOpenAISessionHash
(
c
,
body
)
require
.
NotEmpty
(
t
,
hash3
)
require
.
NotEqual
(
t
,
hash
,
hash3
)
// 不同来源应产生不同 hash
}
func
TestSoraHandleStreamingAwareError_JSONEscaping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
errType
string
message
string
}{
{
name
:
"包含双引号"
,
errType
:
"upstream_error"
,
message
:
`upstream returned "invalid" payload`
,
},
{
name
:
"包含换行和制表符"
,
errType
:
"rate_limit_error"
,
message
:
"line1
\n
line2
\t
tab"
,
},
{
name
:
"包含反斜杠"
,
errType
:
"upstream_error"
,
message
:
`path C:\Users\test\file.txt not found`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
tt
.
errType
,
tt
.
message
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
),
"应以 SSE error 事件开头"
)
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
),
"应以 SSE 结束分隔符结尾"
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
,
"SSE 错误事件应包含 event 行和 data 行"
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
),
"第二行应为 data 前缀"
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
),
"data 行必须是合法 JSON"
)
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
,
"JSON 中应包含 error 对象"
)
require
.
Equal
(
t
,
tt
.
errType
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
tt
.
message
,
errorObj
[
"message"
])
})
}
}
func
TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
resp
:=
[]
byte
(
`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`
)
h
.
handleFailoverExhausted
(
c
,
http
.
StatusBadGateway
,
nil
,
resp
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
))
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errorObj
[
"message"
])
}
func
TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
body
:=
[]
byte
(
`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusForbidden
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare challenge"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d01b0e9ecc35829-SEA"
)
}
func
TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d03b68c086027a1-SEA"
)
body
:=
[]
byte
(
`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusTooManyRequests
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"rate_limit_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare shield"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d03b68c086027a1-SEA"
)
}
func
TestExtractSoraFailoverHeaderInsights
(
t
*
testing
.
T
)
{
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-mitigated"
,
"challenge"
)
headers
.
Set
(
"content-type"
,
"text/html"
)
body
:=
[]
byte
(
`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
headers
,
body
)
require
.
Equal
(
t
,
"9cff2d62d83bb98d"
,
rayID
)
require
.
Equal
(
t
,
"challenge"
,
mitigated
)
require
.
Equal
(
t
,
"text/html"
,
contentType
)
}
backend/internal/handler/usage_record_submit_task_test.go
View file @
62e80c60
...
@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
...
@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
})
require
.
True
(
t
,
called
.
Load
(),
"panic 后后续任务应仍可执行"
)
require
.
True
(
t
,
called
.
Load
(),
"panic 后后续任务应仍可执行"
)
}
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool
(
t
*
testing
.
T
)
{
pool
:=
newUsageRecordTestPool
(
t
)
h
:=
&
SoraGatewayHandler
{
usageRecordWorkerPool
:
pool
}
done
:=
make
(
chan
struct
{})
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
close
(
done
)
})
select
{
case
<-
done
:
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"task not executed"
)
}
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
var
called
atomic
.
Bool
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
_
,
ok
:=
ctx
.
Deadline
();
!
ok
{
t
.
Fatal
(
"expected deadline in fallback context"
)
}
called
.
Store
(
true
)
})
require
.
True
(
t
,
called
.
Load
())
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
require
.
NotPanics
(
t
,
func
()
{
h
.
submitUsageRecordTask
(
nil
)
})
}
func
TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered
(
t
*
testing
.
T
)
{
h
:=
&
SoraGatewayHandler
{}
var
called
atomic
.
Bool
require
.
NotPanics
(
t
,
func
()
{
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
panic
(
"usage task panic"
)
})
})
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
called
.
Store
(
true
)
})
require
.
True
(
t
,
called
.
Load
(),
"panic 后后续任务应仍可执行"
)
}
backend/internal/handler/wire.go
View file @
62e80c60
...
@@ -86,8 +86,6 @@ func ProvideHandlers(
...
@@ -86,8 +86,6 @@ func ProvideHandlers(
adminHandlers
*
AdminHandlers
,
adminHandlers
*
AdminHandlers
,
gatewayHandler
*
GatewayHandler
,
gatewayHandler
*
GatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
soraGatewayHandler
*
SoraGatewayHandler
,
soraClientHandler
*
SoraClientHandler
,
settingHandler
*
SettingHandler
,
settingHandler
*
SettingHandler
,
totpHandler
*
TotpHandler
,
totpHandler
*
TotpHandler
,
_
*
service
.
IdempotencyCoordinator
,
_
*
service
.
IdempotencyCoordinator
,
...
@@ -104,8 +102,6 @@ func ProvideHandlers(
...
@@ -104,8 +102,6 @@ func ProvideHandlers(
Admin
:
adminHandlers
,
Admin
:
adminHandlers
,
Gateway
:
gatewayHandler
,
Gateway
:
gatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
SoraGateway
:
soraGatewayHandler
,
SoraClient
:
soraClientHandler
,
Setting
:
settingHandler
,
Setting
:
settingHandler
,
Totp
:
totpHandler
,
Totp
:
totpHandler
,
}
}
...
@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
...
@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler
,
NewAnnouncementHandler
,
NewGatewayHandler
,
NewGatewayHandler
,
NewOpenAIGatewayHandler
,
NewOpenAIGatewayHandler
,
NewSoraGatewayHandler
,
NewTotpHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
ProvideSettingHandler
,
...
...
backend/internal/pkg/openai/oauth.go
View file @
62e80c60
...
@@ -17,8 +17,6 @@ import (
...
@@ -17,8 +17,6 @@ import (
const
(
const
(
// OAuth Client ID for OpenAI (Codex CLI official)
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID
=
"app_EMoamEEZ73f0CkXaXp7hrann"
ClientID
=
"app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID
=
"app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
// OAuth endpoints
AuthorizeURL
=
"https://auth.openai.com/oauth/authorize"
AuthorizeURL
=
"https://auth.openai.com/oauth/authorize"
...
@@ -39,8 +37,6 @@ const (
...
@@ -39,8 +37,6 @@ const (
const
(
const
(
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI
=
"openai"
OAuthPlatformOpenAI
=
"openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora
=
"sora"
)
)
// OAuthSession stores OAuth flow state for OpenAI
// OAuthSession stores OAuth flow state for OpenAI
...
@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
...
@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
}
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func
OAuthClientConfigByPlatform
(
platform
string
)
(
clientID
string
,
codexFlow
bool
)
{
func
OAuthClientConfigByPlatform
(
platform
string
)
(
clientID
string
,
codexFlow
bool
)
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
platform
))
{
return
ClientID
,
true
case
OAuthPlatformSora
:
return
ClientID
,
false
default
:
return
ClientID
,
true
}
}
}
// TokenRequest represents the token exchange request body
// TokenRequest represents the token exchange request body
...
...
backend/internal/pkg/openai/oauth_test.go
View file @
62e80c60
...
@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
...
@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func
TestBuildAuthorizationURLForPlatform_Sora
(
t
*
testing
.
T
)
{
authURL
:=
BuildAuthorizationURLForPlatform
(
"state-2"
,
"challenge-2"
,
DefaultRedirectURI
,
OAuthPlatformSora
)
parsed
,
err
:=
url
.
Parse
(
authURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"Parse URL failed: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
got
:=
q
.
Get
(
"client_id"
);
got
!=
ClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)"
,
got
,
ClientID
)
}
if
got
:=
q
.
Get
(
"codex_cli_simplified_flow"
);
got
!=
""
{
t
.
Fatalf
(
"codex flow should be empty for sora, got=%q"
,
got
)
}
if
got
:=
q
.
Get
(
"id_token_add_organizations"
);
got
!=
"true"
{
t
.
Fatalf
(
"id_token_add_organizations mismatch: got=%q want=true"
,
got
)
}
}
backend/internal/repository/account_repo.go
View file @
62e80c60
...
@@ -1692,20 +1692,13 @@ func itoa(v int) string {
...
@@ -1692,20 +1692,13 @@ func itoa(v int) string {
}
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func
(
r
*
accountRepository
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
service
.
Account
,
error
)
{
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
accounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
Where
(
dbaccount
.
PlatformEQ
(
"sora"
),
// 限定平台为 sora
dbaccount
.
DeletedAtIsNil
(),
dbaccount
.
DeletedAtIsNil
(),
func
(
s
*
entsql
.
Selector
)
{
func
(
s
*
entsql
.
Selector
)
{
path
:=
sqljson
.
Path
(
key
)
path
:=
sqljson
.
Path
(
key
)
...
...
backend/internal/repository/api_key_repo.go
View file @
62e80c60
...
@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
...
@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldImagePrice1k
,
group
.
FieldImagePrice1k
,
group
.
FieldImagePrice2k
,
group
.
FieldImagePrice2k
,
group
.
FieldImagePrice4k
,
group
.
FieldImagePrice4k
,
group
.
FieldSoraImagePrice360
,
group
.
FieldSoraImagePrice540
,
group
.
FieldSoraVideoPricePerRequest
,
group
.
FieldSoraVideoPricePerRequestHd
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
...
@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
...
@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
Balance
:
u
.
Balance
,
Balance
:
u
.
Balance
,
Concurrency
:
u
.
Concurrency
,
Concurrency
:
u
.
Concurrency
,
Status
:
u
.
Status
,
Status
:
u
.
Status
,
SoraStorageQuotaBytes
:
u
.
SoraStorageQuotaBytes
,
SoraStorageUsedBytes
:
u
.
SoraStorageUsedBytes
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
...
@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
...
@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K
:
g
.
ImagePrice1k
,
ImagePrice1K
:
g
.
ImagePrice1k
,
ImagePrice2K
:
g
.
ImagePrice2k
,
ImagePrice2K
:
g
.
ImagePrice2k
,
ImagePrice4K
:
g
.
ImagePrice4k
,
ImagePrice4K
:
g
.
ImagePrice4k
,
SoraImagePrice360
:
g
.
SoraImagePrice360
,
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHd
,
SoraStorageQuotaBytes
:
g
.
SoraStorageQuotaBytes
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupID
:
g
.
FallbackGroupID
,
...
...
backend/internal/repository/group_repo.go
View file @
62e80c60
...
@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
...
@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableSoraImagePrice360
(
groupIn
.
SoraImagePrice360
)
.
SetNillableSoraImagePrice540
(
groupIn
.
SoraImagePrice540
)
.
SetNillableSoraVideoPricePerRequest
(
groupIn
.
SoraVideoPricePerRequest
)
.
SetNillableSoraVideoPricePerRequestHd
(
groupIn
.
SoraVideoPricePerRequestHD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
...
@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
...
@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice1k
(
groupIn
.
ImagePrice1K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice2k
(
groupIn
.
ImagePrice2K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetNillableSoraImagePrice360
(
groupIn
.
SoraImagePrice360
)
.
SetNillableSoraImagePrice540
(
groupIn
.
SoraImagePrice540
)
.
SetNillableSoraVideoPricePerRequest
(
groupIn
.
SoraVideoPricePerRequest
)
.
SetNillableSoraVideoPricePerRequestHd
(
groupIn
.
SoraVideoPricePerRequestHD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
.
SetSoraStorageQuotaBytes
(
groupIn
.
SoraStorageQuotaBytes
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetAllowMessagesDispatch
(
groupIn
.
AllowMessagesDispatch
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequireOauthOnly
(
groupIn
.
RequireOAuthOnly
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
SetRequirePrivacySet
(
groupIn
.
RequirePrivacySet
)
.
...
...
backend/internal/repository/openai_oauth_service_test.go
View file @
62e80c60
...
@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
...
@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
ClientID
},
seenClientIDs
)
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
ClientID
},
seenClientIDs
)
}
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseSoraClientID
()
{
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
if
clientID
==
openai
.
SoraClientID
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`
)
return
}
w
.
WriteHeader
(
http
.
StatusBadRequest
)
}))
resp
,
err
:=
s
.
svc
.
RefreshTokenWithClientID
(
s
.
ctx
,
"rt"
,
""
,
openai
.
SoraClientID
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshTokenWithClientID"
)
require
.
Equal
(
s
.
T
(),
"at-sora"
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
SoraClientID
},
seenClientIDs
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseProvidedClientID
()
{
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseProvidedClientID
()
{
const
customClientID
=
"custom-client-id"
const
customClientID
=
"custom-client-id"
var
seenClientIDs
[]
string
var
seenClientIDs
[]
string
...
@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
...
@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
}
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestExchangeCode_UseProvidedClientID
()
{
func
(
s
*
OpenAIOAuthServiceSuite
)
TestExchangeCode_UseProvidedClientID
()
{
wantClientID
:=
openai
.
SoraC
lient
ID
wantClientID
:=
"custom-exchange-c
lient
-id"
errCh
:=
make
(
chan
string
,
1
)
errCh
:=
make
(
chan
string
,
1
)
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
=
r
.
ParseForm
()
_
=
r
.
ParseForm
()
...
...
backend/internal/repository/sora_account_repo.go
deleted
100644 → 0
View file @
dbb248df
package
repository
import
(
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type
soraAccountRepository
struct
{
sql
*
sql
.
DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func
NewSoraAccountRepository
(
sqlDB
*
sql
.
DB
)
service
.
SoraAccountRepository
{
return
&
soraAccountRepository
{
sql
:
sqlDB
}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func
(
r
*
soraAccountRepository
)
Upsert
(
ctx
context
.
Context
,
accountID
int64
,
updates
map
[
string
]
any
)
error
{
accessToken
,
accessOK
:=
updates
[
"access_token"
]
.
(
string
)
refreshToken
,
refreshOK
:=
updates
[
"refresh_token"
]
.
(
string
)
sessionToken
,
sessionOK
:=
updates
[
"session_token"
]
.
(
string
)
if
!
accessOK
||
accessToken
==
""
||
!
refreshOK
||
refreshToken
==
""
{
if
!
sessionOK
{
return
errors
.
New
(
"缺少 access_token/refresh_token,且未提供可更新字段"
)
}
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`
,
accountID
,
sessionToken
)
if
err
!=
nil
{
return
err
}
rows
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
rows
==
0
{
return
errors
.
New
(
"sora_accounts 记录不存在,无法仅更新 session_token"
)
}
return
nil
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`
,
accountID
,
accessToken
,
refreshToken
,
sessionToken
)
return
err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func
(
r
*
soraAccountRepository
)
GetByAccountID
(
ctx
context
.
Context
,
accountID
int64
)
(
*
service
.
SoraAccount
,
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
if
!
rows
.
Next
()
{
return
nil
,
nil
// 记录不存在
}
var
sa
service
.
SoraAccount
if
err
:=
rows
.
Scan
(
&
sa
.
AccountID
,
&
sa
.
AccessToken
,
&
sa
.
RefreshToken
,
&
sa
.
SessionToken
);
err
!=
nil
{
return
nil
,
err
}
return
&
sa
,
nil
}
// Delete 删除 Sora 账号扩展信息
func
(
r
*
soraAccountRepository
)
Delete
(
ctx
context
.
Context
,
accountID
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
DELETE FROM sora_accounts WHERE account_id = $1
`
,
accountID
)
return
err
}
backend/internal/repository/sora_generation_repo.go
deleted
100644 → 0
View file @
dbb248df
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type
soraGenerationRepository
struct
{
sql
*
sql
.
DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func
NewSoraGenerationRepository
(
sqlDB
*
sql
.
DB
)
service
.
SoraGenerationRepository
{
return
&
soraGenerationRepository
{
sql
:
sqlDB
}
}
func
(
r
*
soraGenerationRepository
)
Create
(
ctx
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
mediaURLsJSON
,
_
:=
json
.
Marshal
(
gen
.
MediaURLs
)
s3KeysJSON
,
_
:=
json
.
Marshal
(
gen
.
S3ObjectKeys
)
err
:=
r
.
sql
.
QueryRowContext
(
ctx
,
`
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`
,
gen
.
UserID
,
gen
.
APIKeyID
,
gen
.
Model
,
gen
.
Prompt
,
gen
.
MediaType
,
gen
.
Status
,
gen
.
MediaURL
,
mediaURLsJSON
,
gen
.
FileSizeBytes
,
gen
.
StorageType
,
s3KeysJSON
,
gen
.
UpstreamTaskID
,
gen
.
ErrorMessage
,
)
.
Scan
(
&
gen
.
ID
,
&
gen
.
CreatedAt
)
return
err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func
(
r
*
soraGenerationRepository
)
CreatePendingWithLimit
(
ctx
context
.
Context
,
gen
*
service
.
SoraGeneration
,
activeStatuses
[]
string
,
maxActive
int64
,
)
error
{
if
gen
==
nil
{
return
fmt
.
Errorf
(
"generation is nil"
)
}
if
maxActive
<=
0
{
return
r
.
Create
(
ctx
,
gen
)
}
if
len
(
activeStatuses
)
==
0
{
activeStatuses
=
[]
string
{
service
.
SoraGenStatusPending
,
service
.
SoraGenStatusGenerating
}
}
tx
,
err
:=
r
.
sql
.
BeginTx
(
ctx
,
nil
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if
_
,
err
:=
tx
.
ExecContext
(
ctx
,
`SELECT pg_advisory_xact_lock($1)`
,
gen
.
UserID
);
err
!=
nil
{
return
err
}
placeholders
:=
make
([]
string
,
len
(
activeStatuses
))
args
:=
make
([]
any
,
0
,
1
+
len
(
activeStatuses
))
args
=
append
(
args
,
gen
.
UserID
)
for
i
,
s
:=
range
activeStatuses
{
placeholders
[
i
]
=
fmt
.
Sprintf
(
"$%d"
,
i
+
2
)
args
=
append
(
args
,
s
)
}
countQuery
:=
fmt
.
Sprintf
(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`
,
strings
.
Join
(
placeholders
,
","
),
)
var
activeCount
int64
if
err
:=
tx
.
QueryRowContext
(
ctx
,
countQuery
,
args
...
)
.
Scan
(
&
activeCount
);
err
!=
nil
{
return
err
}
if
activeCount
>=
maxActive
{
return
service
.
ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON
,
_
:=
json
.
Marshal
(
gen
.
MediaURLs
)
s3KeysJSON
,
_
:=
json
.
Marshal
(
gen
.
S3ObjectKeys
)
if
err
:=
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`
,
gen
.
UserID
,
gen
.
APIKeyID
,
gen
.
Model
,
gen
.
Prompt
,
gen
.
MediaType
,
gen
.
Status
,
gen
.
MediaURL
,
mediaURLsJSON
,
gen
.
FileSizeBytes
,
gen
.
StorageType
,
s3KeysJSON
,
gen
.
UpstreamTaskID
,
gen
.
ErrorMessage
,
)
.
Scan
(
&
gen
.
ID
,
&
gen
.
CreatedAt
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
(
r
*
soraGenerationRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
SoraGeneration
,
error
)
{
gen
:=
&
service
.
SoraGeneration
{}
var
mediaURLsJSON
,
s3KeysJSON
[]
byte
var
completedAt
sql
.
NullTime
var
apiKeyID
sql
.
NullInt64
err
:=
r
.
sql
.
QueryRowContext
(
ctx
,
`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`
,
id
)
.
Scan
(
&
gen
.
ID
,
&
gen
.
UserID
,
&
apiKeyID
,
&
gen
.
Model
,
&
gen
.
Prompt
,
&
gen
.
MediaType
,
&
gen
.
Status
,
&
gen
.
MediaURL
,
&
mediaURLsJSON
,
&
gen
.
FileSizeBytes
,
&
gen
.
StorageType
,
&
s3KeysJSON
,
&
gen
.
UpstreamTaskID
,
&
gen
.
ErrorMessage
,
&
gen
.
CreatedAt
,
&
completedAt
,
)
if
err
!=
nil
{
if
err
==
sql
.
ErrNoRows
{
return
nil
,
fmt
.
Errorf
(
"生成记录不存在"
)
}
return
nil
,
err
}
if
apiKeyID
.
Valid
{
gen
.
APIKeyID
=
&
apiKeyID
.
Int64
}
if
completedAt
.
Valid
{
gen
.
CompletedAt
=
&
completedAt
.
Time
}
_
=
json
.
Unmarshal
(
mediaURLsJSON
,
&
gen
.
MediaURLs
)
_
=
json
.
Unmarshal
(
s3KeysJSON
,
&
gen
.
S3ObjectKeys
)
return
gen
,
nil
}
func
(
r
*
soraGenerationRepository
)
Update
(
ctx
context
.
Context
,
gen
*
service
.
SoraGeneration
)
error
{
mediaURLsJSON
,
_
:=
json
.
Marshal
(
gen
.
MediaURLs
)
s3KeysJSON
,
_
:=
json
.
Marshal
(
gen
.
S3ObjectKeys
)
var
completedAt
*
time
.
Time
if
gen
.
CompletedAt
!=
nil
{
completedAt
=
gen
.
CompletedAt
}
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`
,
gen
.
ID
,
gen
.
Status
,
gen
.
MediaURL
,
mediaURLsJSON
,
gen
.
FileSizeBytes
,
gen
.
StorageType
,
s3KeysJSON
,
gen
.
UpstreamTaskID
,
gen
.
ErrorMessage
,
completedAt
,
)
return
err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func
(
r
*
soraGenerationRepository
)
UpdateGeneratingIfPending
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
(
bool
,
error
)
{
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`
,
id
,
service
.
SoraGenStatusGenerating
,
upstreamTaskID
,
service
.
SoraGenStatusPending
,
)
if
err
!=
nil
{
return
false
,
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func
(
r
*
soraGenerationRepository
)
UpdateCompletedIfActive
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
completedAt
time
.
Time
,
)
(
bool
,
error
)
{
mediaURLsJSON
,
_
:=
json
.
Marshal
(
mediaURLs
)
s3KeysJSON
,
_
:=
json
.
Marshal
(
s3Keys
)
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`
,
id
,
service
.
SoraGenStatusCompleted
,
mediaURL
,
mediaURLsJSON
,
fileSizeBytes
,
storageType
,
s3KeysJSON
,
completedAt
,
service
.
SoraGenStatusPending
,
service
.
SoraGenStatusGenerating
,
)
if
err
!=
nil
{
return
false
,
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func
(
r
*
soraGenerationRepository
)
UpdateFailedIfActive
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
,
completedAt
time
.
Time
,
)
(
bool
,
error
)
{
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`
,
id
,
service
.
SoraGenStatusFailed
,
errMsg
,
completedAt
,
service
.
SoraGenStatusPending
,
service
.
SoraGenStatusGenerating
,
)
if
err
!=
nil
{
return
false
,
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func
(
r
*
soraGenerationRepository
)
UpdateCancelledIfActive
(
ctx
context
.
Context
,
id
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
{
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`
,
id
,
service
.
SoraGenStatusCancelled
,
completedAt
,
service
.
SoraGenStatusPending
,
service
.
SoraGenStatusGenerating
,
)
if
err
!=
nil
{
return
false
,
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func
(
r
*
soraGenerationRepository
)
UpdateStorageIfCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
)
(
bool
,
error
)
{
mediaURLsJSON
,
_
:=
json
.
Marshal
(
mediaURLs
)
s3KeysJSON
,
_
:=
json
.
Marshal
(
s3Keys
)
result
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`
,
id
,
mediaURL
,
mediaURLsJSON
,
fileSizeBytes
,
storageType
,
s3KeysJSON
,
service
.
SoraGenStatusCompleted
,
)
if
err
!=
nil
{
return
false
,
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
false
,
err
}
return
affected
>
0
,
nil
}
func
(
r
*
soraGenerationRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`DELETE FROM sora_generations WHERE id = $1`
,
id
)
return
err
}
func
(
r
*
soraGenerationRepository
)
List
(
ctx
context
.
Context
,
params
service
.
SoraGenerationListParams
)
([]
*
service
.
SoraGeneration
,
int64
,
error
)
{
// 构建 WHERE 条件
conditions
:=
[]
string
{
"user_id = $1"
}
args
:=
[]
any
{
params
.
UserID
}
argIdx
:=
2
if
params
.
Status
!=
""
{
// 支持逗号分隔的多状态
statuses
:=
strings
.
Split
(
params
.
Status
,
","
)
placeholders
:=
make
([]
string
,
len
(
statuses
))
for
i
,
s
:=
range
statuses
{
placeholders
[
i
]
=
fmt
.
Sprintf
(
"$%d"
,
argIdx
)
args
=
append
(
args
,
strings
.
TrimSpace
(
s
))
argIdx
++
}
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"status IN (%s)"
,
strings
.
Join
(
placeholders
,
","
)))
}
if
params
.
StorageType
!=
""
{
storageTypes
:=
strings
.
Split
(
params
.
StorageType
,
","
)
placeholders
:=
make
([]
string
,
len
(
storageTypes
))
for
i
,
s
:=
range
storageTypes
{
placeholders
[
i
]
=
fmt
.
Sprintf
(
"$%d"
,
argIdx
)
args
=
append
(
args
,
strings
.
TrimSpace
(
s
))
argIdx
++
}
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"storage_type IN (%s)"
,
strings
.
Join
(
placeholders
,
","
)))
}
if
params
.
MediaType
!=
""
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"media_type = $%d"
,
argIdx
))
args
=
append
(
args
,
params
.
MediaType
)
argIdx
++
}
whereClause
:=
"WHERE "
+
strings
.
Join
(
conditions
,
" AND "
)
// 计数
var
total
int64
countQuery
:=
fmt
.
Sprintf
(
"SELECT COUNT(*) FROM sora_generations %s"
,
whereClause
)
if
err
:=
r
.
sql
.
QueryRowContext
(
ctx
,
countQuery
,
args
...
)
.
Scan
(
&
total
);
err
!=
nil
{
return
nil
,
0
,
err
}
// 分页查询
offset
:=
(
params
.
Page
-
1
)
*
params
.
PageSize
listQuery
:=
fmt
.
Sprintf
(
`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`
,
whereClause
,
argIdx
,
argIdx
+
1
)
args
=
append
(
args
,
params
.
PageSize
,
offset
)
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
listQuery
,
args
...
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
results
[]
*
service
.
SoraGeneration
for
rows
.
Next
()
{
gen
:=
&
service
.
SoraGeneration
{}
var
mediaURLsJSON
,
s3KeysJSON
[]
byte
var
completedAt
sql
.
NullTime
var
apiKeyID
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
gen
.
ID
,
&
gen
.
UserID
,
&
apiKeyID
,
&
gen
.
Model
,
&
gen
.
Prompt
,
&
gen
.
MediaType
,
&
gen
.
Status
,
&
gen
.
MediaURL
,
&
mediaURLsJSON
,
&
gen
.
FileSizeBytes
,
&
gen
.
StorageType
,
&
s3KeysJSON
,
&
gen
.
UpstreamTaskID
,
&
gen
.
ErrorMessage
,
&
gen
.
CreatedAt
,
&
completedAt
,
);
err
!=
nil
{
return
nil
,
0
,
err
}
if
apiKeyID
.
Valid
{
gen
.
APIKeyID
=
&
apiKeyID
.
Int64
}
if
completedAt
.
Valid
{
gen
.
CompletedAt
=
&
completedAt
.
Time
}
_
=
json
.
Unmarshal
(
mediaURLsJSON
,
&
gen
.
MediaURLs
)
_
=
json
.
Unmarshal
(
s3KeysJSON
,
&
gen
.
S3ObjectKeys
)
results
=
append
(
results
,
gen
)
}
return
results
,
total
,
rows
.
Err
()
}
func
(
r
*
soraGenerationRepository
)
CountByUserAndStatus
(
ctx
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
{
if
len
(
statuses
)
==
0
{
return
0
,
nil
}
placeholders
:=
make
([]
string
,
len
(
statuses
))
args
:=
[]
any
{
userID
}
for
i
,
s
:=
range
statuses
{
placeholders
[
i
]
=
fmt
.
Sprintf
(
"$%d"
,
i
+
2
)
args
=
append
(
args
,
s
)
}
var
count
int64
query
:=
fmt
.
Sprintf
(
"SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)"
,
strings
.
Join
(
placeholders
,
","
))
err
:=
r
.
sql
.
QueryRowContext
(
ctx
,
query
,
args
...
)
.
Scan
(
&
count
)
return
count
,
err
}
backend/internal/repository/user_repo.go
View file @
62e80c60
...
@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
...
@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance
(
userIn
.
Balance
)
.
SetBalance
(
userIn
.
Balance
)
.
SetConcurrency
(
userIn
.
Concurrency
)
.
SetConcurrency
(
userIn
.
Concurrency
)
.
SetStatus
(
userIn
.
Status
)
.
SetStatus
(
userIn
.
Status
)
.
SetSoraStorageQuotaBytes
(
userIn
.
SoraStorageQuotaBytes
)
.
Save
(
ctx
)
Save
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
...
@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
...
@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance
(
userIn
.
Balance
)
.
SetBalance
(
userIn
.
Balance
)
.
SetConcurrency
(
userIn
.
Concurrency
)
.
SetConcurrency
(
userIn
.
Concurrency
)
.
SetStatus
(
userIn
.
Status
)
.
SetStatus
(
userIn
.
Status
)
.
SetSoraStorageQuotaBytes
(
userIn
.
SoraStorageQuotaBytes
)
.
SetSoraStorageUsedBytes
(
userIn
.
SoraStorageUsedBytes
)
.
Save
(
ctx
)
Save
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
service
.
ErrEmailExists
)
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
service
.
ErrEmailExists
)
...
@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
...
@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return
nil
return
nil
}
}
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func
(
r
*
userRepository
)
AddSoraStorageUsageWithQuota
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
,
effectiveQuota
int64
)
(
int64
,
error
)
{
if
deltaBytes
<=
0
{
user
,
err
:=
r
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
0
,
err
}
return
user
.
SoraStorageUsedBytes
,
nil
}
var
newUsed
int64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
`
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`
,
[]
any
{
userID
,
deltaBytes
,
effectiveQuota
},
&
newUsed
)
if
err
==
nil
{
return
newUsed
,
nil
}
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
// 区分用户不存在和配额冲突
exists
,
existsErr
:=
r
.
client
.
User
.
Query
()
.
Where
(
dbuser
.
IDEQ
(
userID
))
.
Exist
(
ctx
)
if
existsErr
!=
nil
{
return
0
,
existsErr
}
if
!
exists
{
return
0
,
service
.
ErrUserNotFound
}
return
0
,
service
.
ErrSoraStorageQuotaExceeded
}
return
0
,
err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func
(
r
*
userRepository
)
ReleaseSoraStorageUsageAtomic
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
)
(
int64
,
error
)
{
if
deltaBytes
<=
0
{
user
,
err
:=
r
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
0
,
err
}
return
user
.
SoraStorageUsedBytes
,
nil
}
var
newUsed
int64
err
:=
scanSingleRow
(
ctx
,
r
.
sql
,
`
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`
,
[]
any
{
userID
,
deltaBytes
},
&
newUsed
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
0
,
service
.
ErrUserNotFound
}
return
0
,
err
}
return
newUsed
,
nil
}
func
(
r
*
userRepository
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
{
func
(
r
*
userRepository
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
{
return
r
.
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
email
))
.
Exist
(
ctx
)
return
r
.
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
email
))
.
Exist
(
ctx
)
}
}
...
...
backend/internal/repository/wire.go
View file @
62e80c60
...
@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
...
@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository
,
NewAPIKeyRepository
,
NewGroupRepository
,
NewGroupRepository
,
NewAccountRepository
,
NewAccountRepository
,
NewSoraAccountRepository
,
// Sora 账号扩展表仓储
NewScheduledTestPlanRepository
,
// 定时测试计划仓储
NewScheduledTestPlanRepository
,
// 定时测试计划仓储
NewScheduledTestResultRepository
,
// 定时测试结果仓储
NewScheduledTestResultRepository
,
// 定时测试结果仓储
NewProxyRepository
,
NewProxyRepository
,
...
...
backend/internal/server/middleware/security_headers.go
View file @
62e80c60
...
@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
...
@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return
strings
.
HasPrefix
(
path
,
"/v1/"
)
||
return
strings
.
HasPrefix
(
path
,
"/v1/"
)
||
strings
.
HasPrefix
(
path
,
"/v1beta/"
)
||
strings
.
HasPrefix
(
path
,
"/v1beta/"
)
||
strings
.
HasPrefix
(
path
,
"/antigravity/"
)
||
strings
.
HasPrefix
(
path
,
"/antigravity/"
)
||
strings
.
HasPrefix
(
path
,
"/sora/"
)
||
strings
.
HasPrefix
(
path
,
"/responses"
)
strings
.
HasPrefix
(
path
,
"/responses"
)
}
}
...
...
backend/internal/server/router.go
View file @
62e80c60
...
@@ -109,7 +109,6 @@ func registerRoutes(
...
@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
,
settingService
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
,
settingService
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterSoraClientRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
)
}
}
backend/internal/server/routes/admin.go
View file @
62e80c60
...
@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
...
@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
registerOpenAIOAuthRoutes
(
admin
,
h
)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes
(
admin
,
h
)
// Gemini OAuth
// Gemini OAuth
registerGeminiOAuthRoutes
(
admin
,
h
)
registerGeminiOAuthRoutes
(
admin
,
h
)
...
@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
}
}
func
registerSoraOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
sora
:=
admin
.
Group
(
"/sora"
)
{
sora
.
POST
(
"/generate-auth-url"
,
h
.
Admin
.
OpenAIOAuth
.
GenerateAuthURL
)
sora
.
POST
(
"/exchange-code"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeCode
)
sora
.
POST
(
"/refresh-token"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/st2at"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeSoraSessionToken
)
sora
.
POST
(
"/rt2at"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/accounts/:id/refresh"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshAccountToken
)
sora
.
POST
(
"/create-from-oauth"
,
h
.
Admin
.
OpenAIOAuth
.
CreateAccountFromOAuth
)
}
}
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
gemini
:=
admin
.
Group
(
"/gemini"
)
gemini
:=
admin
.
Group
(
"/gemini"
)
{
{
...
...
backend/internal/server/routes/gateway.go
View file @
62e80c60
...
@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
...
@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
{
)
{
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
soraMaxBodySize
:=
cfg
.
Gateway
.
SoraMaxBodySize
if
soraMaxBodySize
<=
0
{
soraMaxBodySize
=
cfg
.
Gateway
.
MaxBodySize
}
soraBodyLimit
:=
middleware
.
RequestBodyLimit
(
soraMaxBodySize
)
clientRequestID
:=
middleware
.
ClientRequestID
()
clientRequestID
:=
middleware
.
ClientRequestID
()
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
endpointNorm
:=
handler
.
InboundEndpointMiddleware
()
endpointNorm
:=
handler
.
InboundEndpointMiddleware
()
...
@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
...
@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
}
}
// Sora 专用路由(强制使用 sora 平台)
soraV1
:=
r
.
Group
(
"/sora/v1"
)
soraV1
.
Use
(
soraBodyLimit
)
soraV1
.
Use
(
clientRequestID
)
soraV1
.
Use
(
opsErrorLogger
)
soraV1
.
Use
(
endpointNorm
)
soraV1
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformSora
))
soraV1
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
soraV1
.
Use
(
requireGroupAnthropic
)
{
soraV1
.
POST
(
"/chat/completions"
,
h
.
SoraGateway
.
ChatCompletions
)
soraV1
.
GET
(
"/models"
,
h
.
Gateway
.
Models
)
}
// Sora 媒体代理(可选 API Key 验证)
if
cfg
.
Gateway
.
SoraMediaRequireAPIKey
{
r
.
GET
(
"/sora/media/*filepath"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
SoraGateway
.
MediaProxy
)
}
else
{
r
.
GET
(
"/sora/media/*filepath"
,
h
.
SoraGateway
.
MediaProxy
)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r
.
GET
(
"/sora/media-signed/*filepath"
,
h
.
SoraGateway
.
MediaProxySigned
)
}
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
// getGroupPlatform extracts the group platform from the API Key stored in context.
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment