Unverified Commit 9d795061 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
......@@ -43,6 +43,7 @@ type SoraVideoRequest struct {
Frames int
Model string
Size string
VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string
......
......@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
......@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
......@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
......@@ -115,6 +116,15 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
if s.httpUpstream == nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
}
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
}
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
......@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
taskID := ""
var err error
videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
......@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
......@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
if storeErr != nil {
// 存储失败时降级使用原始 URL,不中断用户请求
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
} else {
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
......@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
}
}
func parseSoraVideoCount(body map[string]any) int {
if body == nil {
return 1
}
keys := []string{"video_count", "videos", "n_variants"}
for _, key := range keys {
count := parseIntWithDefault(body, key, 0)
if count > 0 {
return clampInt(count, 1, 3)
}
}
return 1
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
......@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
return def
}
func parseIntWithDefault(body map[string]any, key string, def int) int {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float64:
return int(typed)
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return parsed
}
}
return def
}
func clampInt(v, minVal, maxVal int) int {
if v < minVal {
return minVal
}
if v > maxVal {
return maxVal
}
return v
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
......@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora",
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
accountID,
model,
upstreamErr.StatusCode,
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
strings.TrimSpace(upstreamErr.Message),
truncateForLog(upstreamErr.Body, 1024),
)
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
......
......@@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 3, client.videoReq.VideoCount)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
......@@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
func TestParseSoraVideoCount(t *testing.T) {
require.Equal(t, 1, parseSoraVideoCount(nil))
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
}
package service
import (
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type SoraGeneration struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
MediaType string `json:"media_type"` // video / image
Status string `json:"status"` // pending / generating / completed / failed / cancelled
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
FileSizeBytes int64 `json:"file_size_bytes"`
StorageType string `json:"storage_type"` // s3 / local / upstream / none
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
UpstreamTaskID string `json:"upstream_task_id"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const (
SoraGenStatusPending = "pending"
SoraGenStatusGenerating = "generating"
SoraGenStatusCompleted = "completed"
SoraGenStatusFailed = "failed"
SoraGenStatusCancelled = "cancelled"
)
// Sora 存储类型常量
const (
SoraStorageTypeS3 = "s3"
SoraStorageTypeLocal = "local"
SoraStorageTypeUpstream = "upstream"
SoraStorageTypeNone = "none"
)
// SoraGenerationListParams 查询生成记录的参数。
type SoraGenerationListParams struct {
UserID int64
Status string // 可选筛选
StorageType string // 可选筛选
MediaType string // 可选筛选
Page int
PageSize int
}
// SoraGenerationRepository 生成记录持久化接口。
type SoraGenerationRepository interface {
Create(ctx context.Context, gen *SoraGeneration) error
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
Update(ctx context.Context, gen *SoraGeneration) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment