Unverified Commit 9d795061 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
......@@ -43,6 +43,7 @@ type SoraVideoRequest struct {
Frames int
Model string
Size string
VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string
......
......@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
......@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
......@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
......@@ -115,6 +116,15 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
if s.httpUpstream == nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
}
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
}
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
......@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
taskID := ""
var err error
videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
......@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
......@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
if storeErr != nil {
// 存储失败时降级使用原始 URL,不中断用户请求
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
} else {
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
......@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
}
}
func parseSoraVideoCount(body map[string]any) int {
if body == nil {
return 1
}
keys := []string{"video_count", "videos", "n_variants"}
for _, key := range keys {
count := parseIntWithDefault(body, key, 0)
if count > 0 {
return clampInt(count, 1, 3)
}
}
return 1
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
......@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
return def
}
func parseIntWithDefault(body map[string]any, key string, def int) int {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float64:
return int(typed)
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return parsed
}
}
return def
}
func clampInt(v, minVal, maxVal int) int {
if v < minVal {
return minVal
}
if v > maxVal {
return maxVal
}
return v
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
......@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora",
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
accountID,
model,
upstreamErr.StatusCode,
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
strings.TrimSpace(upstreamErr.Message),
truncateForLog(upstreamErr.Body, 1024),
)
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
......
......@@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 3, client.videoReq.VideoCount)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
......@@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
func TestParseSoraVideoCount(t *testing.T) {
require.Equal(t, 1, parseSoraVideoCount(nil))
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
}
package service
import (
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type SoraGeneration struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
MediaType string `json:"media_type"` // video / image
Status string `json:"status"` // pending / generating / completed / failed / cancelled
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
FileSizeBytes int64 `json:"file_size_bytes"`
StorageType string `json:"storage_type"` // s3 / local / upstream / none
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
UpstreamTaskID string `json:"upstream_task_id"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const (
SoraGenStatusPending = "pending"
SoraGenStatusGenerating = "generating"
SoraGenStatusCompleted = "completed"
SoraGenStatusFailed = "failed"
SoraGenStatusCancelled = "cancelled"
)
// Sora 存储类型常量
const (
SoraStorageTypeS3 = "s3"
SoraStorageTypeLocal = "local"
SoraStorageTypeUpstream = "upstream"
SoraStorageTypeNone = "none"
)
// SoraGenerationListParams 查询生成记录的参数。
type SoraGenerationListParams struct {
UserID int64
Status string // 可选筛选
StorageType string // 可选筛选
MediaType string // 可选筛选
Page int
PageSize int
}
// SoraGenerationRepository 生成记录持久化接口。
type SoraGenerationRepository interface {
Create(ctx context.Context, gen *SoraGeneration) error
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
Update(ctx context.Context, gen *SoraGeneration) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
}
package service
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive = errors.New("sora generation is not active")
)
const soraGenerationActiveLimit = 3
type soraGenerationRepoAtomicCreator interface {
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
}
type soraGenerationRepoConditionalUpdater interface {
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type SoraGenerationService struct {
genRepo SoraGenerationRepository
s3Storage *SoraS3Storage
quotaService *SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func NewSoraGenerationService(
genRepo SoraGenerationRepository,
s3Storage *SoraS3Storage,
quotaService *SoraQuotaService,
) *SoraGenerationService {
return &SoraGenerationService{
genRepo: genRepo,
s3Storage: s3Storage,
quotaService: quotaService,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
gen := &SoraGeneration{
UserID: userID,
APIKeyID: apiKeyID,
Model: model,
Prompt: prompt,
MediaType: mediaType,
Status: SoraGenStatusPending,
StorageType: SoraStorageTypeNone,
}
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
if err := atomicCreator.CreatePendingWithLimit(
ctx,
gen,
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
soraGenerationActiveLimit,
); err != nil {
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
return nil, err
}
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
if err := s.genRepo.Create(ctx, gen); err != nil {
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
// MarkGenerating 标记为生成中。
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusGenerating
gen.UpstreamTaskID = upstreamTaskID
return s.genRepo.Update(ctx, gen)
}
// MarkCompleted 标记为已完成。
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusCompleted
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkFailed 标记为失败。
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusFailed
gen.ErrorMessage = errMsg
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkCancelled 标记为已取消。
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationNotActive
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationNotActive
}
gen.Status = SoraGenStatusCancelled
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func (s *SoraGenerationService) UpdateStorageForCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusCompleted {
return ErrSoraGenerationStateConflict
}
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
return s.genRepo.Update(ctx, gen)
}
// GetByID 获取记录详情(含权限校验)。
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if gen.UserID != userID {
return nil, fmt.Errorf("无权访问此生成记录")
}
return gen, nil
}
// List 查询生成记录列表(分页 + 筛选)。
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
if params.PageSize > 100 {
params.PageSize = 100
}
return s.genRepo.List(ctx, params)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.UserID != userID {
return fmt.Errorf("无权删除此生成记录")
}
// 清理 S3 文件
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
}
}
// 释放配额(S3/本地均释放)
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
}
}
return s.genRepo.Delete(ctx, id)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
return nil
}
if len(gen.S3ObjectKeys) == 0 {
return nil
}
urls := make([]string, len(gen.S3ObjectKeys))
var wg sync.WaitGroup
var firstErr error
var errMu sync.Mutex
for idx, key := range gen.S3ObjectKeys {
wg.Add(1)
go func(i int, objectKey string) {
defer wg.Done()
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
if err != nil {
errMu.Lock()
if firstErr == nil {
firstErr = err
}
errMu.Unlock()
return
}
urls[i] = url
}(idx, key)
}
wg.Wait()
if firstErr != nil {
return firstErr
}
gen.MediaURL = urls[0]
gen.MediaURLs = urls
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment