Commit eb2dce92 authored by 陈曦's avatar 陈曦
Browse files

升级v1.0.8 解决冲突

parents 7b83d6e7 339d906e
package service
import (
"context"
"fmt"
"io"
"net/http"
"path"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
type SoraS3Storage struct {
settingService *SettingService
mu sync.RWMutex
client *s3.Client
cfg *SoraS3Settings // 上次加载的配置快照
healthCheckedAt time.Time
healthErr error
healthTTL time.Duration
}
const defaultSoraS3HealthTTL = 30 * time.Second
// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
type UpstreamDownloadError struct {
StatusCode int
}
func (e *UpstreamDownloadError) Error() string {
if e == nil {
return "upstream download failed"
}
return fmt.Sprintf("upstream returned %d", e.StatusCode)
}
// NewSoraS3Storage 创建 S3 存储服务实例。
func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
return &SoraS3Storage{
settingService: settingService,
healthTTL: defaultSoraS3HealthTTL,
}
}
// Enabled 返回 S3 存储是否已启用且配置有效。
func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
cfg, err := s.getConfig(ctx)
if err != nil || cfg == nil {
return false
}
return cfg.Enabled && cfg.Bucket != ""
}
// getConfig 获取当前 S3 配置(从 settings 表读取)。
func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
if s.settingService == nil {
return nil, fmt.Errorf("setting service not available")
}
return s.settingService.GetSoraS3Settings(ctx)
}
// getClient 获取或初始化 S3 客户端(带缓存)。
// 配置变更时调用 RefreshClient 清除缓存。
func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
s.mu.RLock()
if s.client != nil && s.cfg != nil {
client, cfg := s.client, s.cfg
s.mu.RUnlock()
return client, cfg, nil
}
s.mu.RUnlock()
return s.initClient(ctx)
}
func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
s.mu.Lock()
defer s.mu.Unlock()
// 双重检查
if s.client != nil && s.cfg != nil {
return s.client, s.cfg, nil
}
cfg, err := s.getConfig(ctx)
if err != nil {
return nil, nil, fmt.Errorf("load s3 config: %w", err)
}
if !cfg.Enabled {
return nil, nil, fmt.Errorf("sora s3 storage is disabled")
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
}
client, region, err := buildSoraS3Client(ctx, cfg)
if err != nil {
return nil, nil, err
}
s.client = client
s.cfg = cfg
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
return client, cfg, nil
}
// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
// 应在系统设置中 S3 配置变更时调用。
func (s *SoraS3Storage) RefreshClient() {
s.mu.Lock()
defer s.mu.Unlock()
s.client = nil
s.cfg = nil
s.healthCheckedAt = time.Time{}
s.healthErr = nil
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
}
// TestConnection 测试 S3 连接(HeadBucket)。
func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
client, cfg, err := s.getClient(ctx)
if err != nil {
return err
}
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &cfg.Bucket,
})
if err != nil {
return fmt.Errorf("s3 HeadBucket failed: %w", err)
}
return nil
}
// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
if s == nil {
return false
}
now := time.Now()
s.mu.RLock()
lastCheck := s.healthCheckedAt
lastErr := s.healthErr
ttl := s.healthTTL
s.mu.RUnlock()
if ttl <= 0 {
ttl = defaultSoraS3HealthTTL
}
if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl {
return lastErr == nil
}
err := s.TestConnection(ctx)
s.mu.Lock()
s.healthCheckedAt = time.Now()
s.healthErr = err
s.mu.Unlock()
return err == nil
}
// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
if cfg == nil {
return fmt.Errorf("s3 config is required")
}
if !cfg.Enabled {
return fmt.Errorf("sora s3 storage is disabled")
}
if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
}
client, _, err := buildSoraS3Client(ctx, cfg)
if err != nil {
return err
}
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &cfg.Bucket,
})
if err != nil {
return fmt.Errorf("s3 HeadBucket failed: %w", err)
}
return nil
}
// GenerateObjectKey 生成 S3 object key。
// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
datePath := time.Now().Format("2006/01/02")
key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
if prefix != "" {
prefix = strings.TrimRight(prefix, "/") + "/"
key = prefix + key
}
return key
}
// UploadFromURL 从上游 URL 下载并流式上传到 S3。
// 返回 S3 object key。
func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
client, cfg, err := s.getClient(ctx)
if err != nil {
return "", 0, err
}
// 下载源文件
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
if err != nil {
return "", 0, fmt.Errorf("create download request: %w", err)
}
httpClient := &http.Client{Timeout: 5 * time.Minute}
resp, err := httpClient.Do(req)
if err != nil {
return "", 0, fmt.Errorf("download from upstream: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
}
// 推断文件扩展名
ext := fileExtFromURL(sourceURL)
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
if ext == "" {
ext = ".bin"
}
objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)
// 检测 Content-Type
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream"
}
reader, writer := io.Pipe()
uploadErrCh := make(chan error, 1)
go func() {
defer close(uploadErrCh)
input := &s3.PutObjectInput{
Bucket: &cfg.Bucket,
Key: &objectKey,
Body: reader,
ContentType: &contentType,
}
if resp.ContentLength >= 0 {
input.ContentLength = &resp.ContentLength
}
_, uploadErr := client.PutObject(ctx, input)
uploadErrCh <- uploadErr
}()
written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
_ = writer.CloseWithError(copyErr)
uploadErr := <-uploadErrCh
if copyErr != nil {
return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
}
if uploadErr != nil {
return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
}
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
return objectKey, written, nil
}
func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
if cfg == nil {
return nil, "", fmt.Errorf("s3 config is required")
}
region := cfg.Region
if region == "" {
region = "us-east-1"
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
),
)
if err != nil {
return nil, "", fmt.Errorf("load aws config: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = &cfg.Endpoint
}
if cfg.ForcePathStyle {
o.UsePathStyle = true
}
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
// 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
})
return client, region, nil
}
// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
if len(objectKeys) == 0 {
return nil
}
client, cfg, err := s.getClient(ctx)
if err != nil {
return err
}
var lastErr error
for _, key := range objectKeys {
k := key
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &cfg.Bucket,
Key: &k,
})
if err != nil {
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err)
lastErr = err
}
}
return lastErr
}
// GetAccessURL 获取 S3 文件的访问 URL。
// CDN URL 优先,否则生成 24h 预签名 URL。
func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
_, cfg, err := s.getClient(ctx)
if err != nil {
return "", err
}
// CDN URL 优先
if cfg.CDNURL != "" {
cdnBase := strings.TrimRight(cfg.CDNURL, "/")
return cdnBase + "/" + objectKey, nil
}
// 生成 24h 预签名 URL
return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour)
}
// GeneratePresignedURL 生成预签名 URL。
func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
client, cfg, err := s.getClient(ctx)
if err != nil {
return "", err
}
presignClient := s3.NewPresignClient(client)
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: &cfg.Bucket,
Key: &objectKey,
}, s3.WithPresignExpires(ttl))
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return result.URL, nil
}
// GetMediaType 从 object key 推断媒体类型(image/video)。
func GetMediaTypeFromKey(objectKey string) string {
ext := strings.ToLower(path.Ext(objectKey))
switch ext {
case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
return "video"
default:
return "image"
}
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ==================== RefreshClient ====================
func TestRefreshClient(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com")
require.NotNil(t, s.client)
require.NotNil(t, s.cfg)
s.RefreshClient()
require.Nil(t, s.client)
require.Nil(t, s.cfg)
}
func TestRefreshClient_AlreadyNil(t *testing.T) {
s := NewSoraS3Storage(nil)
s.RefreshClient() // 不应 panic
require.Nil(t, s.client)
require.Nil(t, s.cfg)
}
// ==================== GetMediaTypeFromKey ====================
func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) {
for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} {
require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext)
}
}
func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) {
require.Equal(t, "video", GetMediaTypeFromKey("file.MP4"))
require.Equal(t, "video", GetMediaTypeFromKey("file.MOV"))
}
func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file.png"))
require.Equal(t, "image", GetMediaTypeFromKey("file.jpg"))
require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg"))
require.Equal(t, "image", GetMediaTypeFromKey("file.gif"))
require.Equal(t, "image", GetMediaTypeFromKey("file.webp"))
}
func TestGetMediaTypeFromKey_NoExtension(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file"))
require.Equal(t, "image", GetMediaTypeFromKey("path/to/file"))
}
func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file.bin"))
require.Equal(t, "image", GetMediaTypeFromKey("file.xyz"))
}
// ==================== Enabled ====================
func TestEnabled_NilSettingService(t *testing.T) {
s := NewSoraS3Storage(nil)
require.False(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigDisabled(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "false",
SettingKeySoraS3Bucket: "test-bucket",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.False(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigEnabledWithBucket(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "my-bucket",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.True(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.False(t, s.Enabled(context.Background()))
}
// ==================== initClient ====================
func TestInitClient_Disabled(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "false",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "disabled")
}
func TestInitClient_IncompleteConfig(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
// 缺少 access_key_id 和 secret_access_key
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestInitClient_DefaultRegion(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
SettingKeySoraS3AccessKeyID: "AKID",
SettingKeySoraS3SecretAccessKey: "SECRET",
// Region 为空 → 默认 us-east-1
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
client, cfg, err := s.getClient(context.Background())
require.NoError(t, err)
require.NotNil(t, client)
require.Equal(t, "test-bucket", cfg.Bucket)
}
func TestInitClient_DoubleCheck(t *testing.T) {
// 验证双重检查锁定:第二次 getClient 命中缓存
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
SettingKeySoraS3AccessKeyID: "AKID",
SettingKeySoraS3SecretAccessKey: "SECRET",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
client1, _, err1 := s.getClient(context.Background())
require.NoError(t, err1)
client2, _, err2 := s.getClient(context.Background())
require.NoError(t, err2)
require.Equal(t, client1, client2) // 同一客户端实例
}
func TestInitClient_NilSettingService(t *testing.T) {
s := NewSoraS3Storage(nil)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "setting service not available")
}
// ==================== GenerateObjectKey ====================
func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("", 1, "mp4")
require.Contains(t, key, ".mp4")
require.True(t, len(key) > 0)
}
func TestGenerateObjectKey_ExtWithDot(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("", 1, ".mp4")
require.Contains(t, key, ".mp4")
// 不应出现 ..mp4
require.NotContains(t, key, "..mp4")
}
func TestGenerateObjectKey_WithPrefix(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("uploads/", 42, ".png")
require.True(t, len(key) > 0)
require.Contains(t, key, "uploads/sora/42/")
}
func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("uploads", 42, ".png")
require.Contains(t, key, "uploads/sora/42/")
}
// ==================== GeneratePresignedURL ====================
func TestGeneratePresignedURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败
_, err := s.GeneratePresignedURL(context.Background(), "key", 3600)
require.Error(t, err)
}
// ==================== GetAccessURL ====================
func TestGetAccessURL_CDN(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com")
url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4")
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url)
}
func TestGetAccessURL_CDNTrailingSlash(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com/")
url, err := s.GetAccessURL(context.Background(), "key.mp4")
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/key.mp4", url)
}
func TestGetAccessURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
_, err := s.GetAccessURL(context.Background(), "key")
require.Error(t, err)
}
// ==================== TestConnection ====================
func TestTestConnection_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.TestConnection(context.Background())
require.Error(t, err)
}
// ==================== UploadFromURL ====================
func TestUploadFromURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
_, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4")
require.Error(t, err)
}
// ==================== DeleteObjects ====================
func TestDeleteObjects_EmptyKeys(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), []string{})
require.NoError(t, err) // 空列表直接返回
}
func TestDeleteObjects_NilKeys(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), nil)
require.NoError(t, err) // nil 列表直接返回
}
func TestDeleteObjects_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), []string{"key1", "key2"})
require.Error(t, err)
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/DouDOU-start/go-sora2api/sora"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/tidwall/gjson"
)
// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。
// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。
type SoraSDKClient struct {
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
// 每个 proxyURL 对应一个 SDK 客户端实例
sdkClients sync.Map // key: proxyURL (string), value: *sora.Client
}
// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端
func NewSoraSDKClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraSDKClient {
return &SoraSDKClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
}
}
// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化)
func (c *SoraSDKClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
if c == nil {
return
}
c.accountRepo = accountRepo
c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora
func (c *SoraSDKClient) Enabled() bool {
if c == nil || c.cfg == nil {
return false
}
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。
func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
if modelCfg.Type != "video" {
return nil
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
balance, err := sdkClient.GetCreditBalance(ctx, token)
if err != nil {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora_sdk",
"[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s",
accountID,
requestedModel,
logredact.RedactText(err.Error()),
)
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
}
}
if balance.RateLimitReached || balance.RemainingCount <= 0 {
msg := "当前账号 Sora2 可用配额不足"
if requestedModel != "" {
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: msg,
}
}
return nil
}
func (c *SoraSDKClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
if len(data) == 0 {
return "", errors.New("empty image data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
if filename == "" {
filename = "image.png"
}
mediaID, err := sdkClient.UploadImage(ctx, token, data, filename)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return mediaID, nil
}
func (c *SoraSDKClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
var taskID string
if strings.TrimSpace(req.MediaID) != "" {
taskID, err = sdkClient.CreateImageTaskWithImage(ctx, token, sentinel, req.Prompt, req.Width, req.Height, req.MediaID)
} else {
taskID, err = sdkClient.CreateImageTask(ctx, token, sentinel, req.Prompt, req.Width, req.Height)
}
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
videoCount := req.VideoCount
if videoCount <= 0 {
videoCount = 1
}
if videoCount > 3 {
videoCount = 3
}
// Remix 模式
if strings.TrimSpace(req.RemixTargetID) != "" {
if videoCount > 1 {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount)
}
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
// 普通视频(文生视频或图生视频)
var taskID string
if videoCount <= 1 {
taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
} else {
taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount)
}
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) createVideoTaskWithVariants(
ctx context.Context,
account *Account,
accessToken string,
sentinelToken string,
prompt string,
orientation string,
nFrames int,
model string,
size string,
mediaID string,
videoCount int,
) (string, error) {
inpaintItems := make([]any, 0, 1)
if strings.TrimSpace(mediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": mediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": prompt,
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"n_variants": videoCount,
"model": model,
"inpaint_items": inpaintItems,
"style_id": nil,
}
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String())
if taskID == "" {
return "", errors.New("create video task response missing id")
}
return taskID, nil
}
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
taskID, err := sdkClient.CreateStoryboardTask(ctx, token, sentinel, req.Prompt, orientation, nFrames, req.MediaID, "")
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
cameoID, err := sdkClient.UploadCharacterVideo(ctx, token, data)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return cameoID, nil
}
func (c *SoraSDKClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
status, err := sdkClient.GetCameoStatus(ctx, token, cameoID)
if err != nil {
return nil, c.wrapSDKError(err, account)
}
return &SoraCameoStatus{
Status: status.Status,
DisplayNameHint: status.DisplayNameHint,
UsernameHint: status.UsernameHint,
ProfileAssetURL: status.ProfileAssetURL,
}, nil
}
func (c *SoraSDKClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
data, err := sdkClient.DownloadCharacterImage(ctx, imageURL)
if err != nil {
return nil, c.wrapSDKError(err, account)
}
return data, nil
}
func (c *SoraSDKClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
assetPointer, err := sdkClient.UploadCharacterImage(ctx, token, data)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return assetPointer, nil
}
func (c *SoraSDKClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
characterID, err := sdkClient.FinalizeCharacter(ctx, token, req.CameoID, req.Username, req.DisplayName, req.ProfileAssetPointer)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return characterID, nil
}
func (c *SoraSDKClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.SetCharacterPublic(ctx, token, cameoID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
func (c *SoraSDKClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.DeleteCharacter(ctx, token, characterID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
func (c *SoraSDKClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
postID, err := sdkClient.PublishVideo(ctx, token, sentinel, generationID)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return postID, nil
}
func (c *SoraSDKClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.DeletePost(ctx, token, postID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。
// SDK 不涉及此功能,保留自建实现。
func (c *SoraSDKClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraSDKClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
if strings.TrimSpace(expansionLevel) == "" {
expansionLevel = "medium"
}
if durationS <= 0 {
durationS = 10
}
enhanced, err := sdkClient.EnhancePrompt(ctx, token, prompt, expansionLevel, durationS)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return enhanced, nil
}
func (c *SoraSDKClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
result := sdkClient.QueryImageTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second))
if result.Err != nil {
return &SoraImageTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: result.Err.Error(),
}, nil
}
if result.Done && result.ImageURL != "" {
return &SoraImageTaskStatus{
ID: taskID,
Status: "succeeded",
URLs: []string{result.ImageURL},
}, nil
}
status := result.Progress.Status
if status == "" {
status = "processing"
}
return &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: float64(result.Progress.Percent) / 100.0,
}, nil
}
func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
// 先查询 pending 列表
result := sdkClient.QueryVideoTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second), 0)
if result.Err != nil {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: result.Err.Error(),
}, nil
}
if !result.Done {
return &SoraVideoTaskStatus{
ID: taskID,
Status: result.Progress.Status,
ProgressPct: result.Progress.Percent,
}, nil
}
// 任务不在 pending 中,查询 drafts 获取下载链接
downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: errMsg,
}, nil
}
// 可能还在处理中
return &SoraVideoTaskStatus{
ID: taskID,
Status: "processing",
}, nil
}
if len(downloadURLs) == 0 {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "processing",
}, nil
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: downloadURLs,
}, nil
}
func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) {
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil)
if err != nil {
return nil, err
}
items := gjson.GetBytes(raw, "items")
if !items.Exists() || !items.IsArray() {
return nil, fmt.Errorf("drafts response missing items for task %s", taskID)
}
urlSet := make(map[string]struct{}, 4)
urls := make([]string, 0, 4)
items.ForEach(func(_, item gjson.Result) bool {
if strings.TrimSpace(item.Get("task_id").String()) != taskID {
return true
}
kind := strings.TrimSpace(item.Get("kind").String())
reason := strings.TrimSpace(item.Get("reason_str").String())
markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String())
if kind == "sora_content_violation" || reason != "" || markdownReason != "" {
if reason == "" {
reason = markdownReason
}
if reason == "" {
reason = "内容违规"
}
err = fmt.Errorf("内容违规: %s", reason)
return false
}
url := strings.TrimSpace(item.Get("downloadable_url").String())
if url == "" {
url = strings.TrimSpace(item.Get("url").String())
}
if url == "" {
return true
}
if _, exists := urlSet[url]; exists {
return true
}
urlSet[url] = struct{}{}
urls = append(urls, url)
return true
})
if err != nil {
return nil, err
}
if len(urls) > 0 {
return urls, nil
}
// 兼容旧 SDK 的兜底逻辑
sdkClient, sdkErr := c.getSDKClient(account)
if sdkErr != nil {
return nil, sdkErr
}
downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID)
if sdkErr != nil {
return nil, sdkErr
}
if strings.TrimSpace(downloadURL) == "" {
return nil, nil
}
return []string{downloadURL}, nil
}
func (c *SoraSDKClient) doSoraBackendJSON(
ctx context.Context,
account *Account,
method string,
path string,
accessToken string,
sentinelToken string,
payload map[string]any,
) ([]byte, error) {
endpoint := "https://sora.chatgpt.com/backend" + path
var body io.Reader
if payload != nil {
raw, err := json.Marshal(payload)
if err != nil {
return nil, err
}
body = bytes.NewReader(raw)
}
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
if strings.TrimSpace(sentinelToken) != "" {
req.Header.Set("openai-sentinel-token", sentinelToken)
}
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256))
}
return raw, nil
}
// --- 内部方法 ---
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) {
proxyURL := c.resolveProxyURL(account)
if v, ok := c.sdkClients.Load(proxyURL); ok {
if cli, ok2 := v.(*sora.Client); ok2 {
return cli, nil
}
}
client, err := sora.New(proxyURL)
if err != nil {
return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err)
}
actual, _ := c.sdkClients.LoadOrStore(proxyURL, client)
if cli, ok := actual.(*sora.Client); ok {
return cli, nil
}
return client, nil
}
func (c *SoraSDKClient) resolveProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return strings.TrimSpace(account.Proxy.URL())
}
// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。
// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。
func (c *SoraSDKClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
// 优先尝试 OpenAI Token Provider
allowProvider := c.allowOpenAITokenProvider(account)
var providerErr error
if allowProvider && c.tokenProvider != nil {
token, err := c.tokenProvider.GetAccessToken(ctx, account)
if err == nil && strings.TrimSpace(token) != "" {
c.debugLogf("token_selected account_id=%d source=openai_token_provider", account.ID)
return token, nil
}
providerErr = err
if err != nil && c.debugEnabled() {
c.debugLogf("token_provider_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
// 尝试直接使用 credentials 中的 access_token
token := strings.TrimSpace(account.GetCredential("access_token"))
if token != "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
return refreshed, nil
}
}
return token, nil
}
// 尝试通过 session_token 或 refresh_token 恢复
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
return recovered, nil
}
if providerErr != nil {
return "", providerErr
}
return "", errors.New("access_token not found")
}
// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token
func (c *SoraSDKClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
// 先尝试 session_token
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
if err == nil && strings.TrimSpace(accessToken) != "" {
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
return accessToken, nil
}
}
// 再尝试 refresh_token
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
return "", errors.New("session_token/refresh_token not found")
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
// 尝试多个 client_id
clientIDs := []string{
strings.TrimSpace(account.GetCredential("client_id")),
openaioauth.SoraClientID,
openaioauth.ClientID,
}
tried := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
if clientID == "" {
continue
}
if _, ok := tried[clientID]; ok {
continue
}
tried[clientID] = struct{}{}
newAccess, newRefresh, refreshErr := sdkClient.RefreshAccessToken(ctx, refreshToken, clientID)
if refreshErr != nil {
lastErr = refreshErr
continue
}
if strings.TrimSpace(newAccess) == "" {
lastErr = errors.New("refreshed access_token is empty")
continue
}
c.applyRecoveredToken(ctx, account, newAccess, newRefresh, "", "")
return newAccess, nil
}
if lastErr != nil {
return "", lastErr
}
return "", errors.New("no available client_id for refresh_token exchange")
}
// exchangeSessionToken 通过 session_token 换取 access_token
func (c *SoraSDKClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://sora.chatgpt.com/api/auth/session", nil)
if err != nil {
return "", "", err
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", "", err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if err != nil {
return "", "", err
}
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("session exchange failed: %d", resp.StatusCode)
}
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
if accessToken == "" {
return "", "", errors.New("session exchange missing accessToken")
}
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
return accessToken, expiresAt, nil
}
// applyRecoveredToken 将恢复的 token 写入账号内存和数据库
func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
if account == nil {
return
}
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
if strings.TrimSpace(accessToken) != "" {
account.Credentials["access_token"] = accessToken
}
if strings.TrimSpace(refreshToken) != "" {
account.Credentials["refresh_token"] = refreshToken
}
if strings.TrimSpace(expiresAt) != "" {
account.Credentials["expires_at"] = expiresAt
}
if strings.TrimSpace(sessionToken) != "" {
account.Credentials["session_token"] = sessionToken
}
if c.accountRepo != nil {
if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
}
func (c *SoraSDKClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
return
}
updates := make(map[string]any)
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
updates["access_token"] = accessToken
updates["refresh_token"] = refreshToken
}
if strings.TrimSpace(sessionToken) != "" {
updates["session_token"] = sessionToken
}
if len(updates) == 0 {
return
}
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
func (c *SoraSDKClient) allowOpenAITokenProvider(account *Account) bool {
if c == nil || c.tokenProvider == nil {
return false
}
if account != nil && account.Platform == PlatformSora {
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
}
return true
}
// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError
func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
if err == nil {
return nil
}
msg := err.Error()
statusCode := http.StatusBadGateway
if strings.Contains(msg, "HTTP 401") || strings.Contains(msg, "HTTP 403") {
statusCode = http.StatusUnauthorized
} else if strings.Contains(msg, "HTTP 429") {
statusCode = http.StatusTooManyRequests
} else if strings.Contains(msg, "HTTP 404") {
statusCode = http.StatusNotFound
}
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora_sdk",
"[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s",
accountID,
statusCode,
logredact.RedactText(msg),
)
return &SoraUpstreamError{
StatusCode: statusCode,
Message: msg,
}
}
func (c *SoraSDKClient) debugEnabled() bool {
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
}
func (c *SoraSDKClient) debugLogf(format string, args ...any) {
if c.debugEnabled() {
log.Printf("[SoraSDK] "+format, args...)
}
}
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func (s *SoraGatewayService) forwardToUpstream(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
clientStream bool,
startTime time.Time,
) (*ForwardResult, error) {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
}
baseURL := account.GetBaseURL()
if baseURL == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
}
// 校验 scheme 合法性(仅允许 http/https)
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
}
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
// 构建上游请求
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
return nil, fmt.Errorf("create upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
// 透传客户端的部分请求头
for _, header := range []string{"Accept", "Accept-Encoding"} {
if v := c.GetHeader(header); v != "" {
upstreamReq.Header.Set(header, v)
}
}
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
}
}
defer func() {
_ = resp.Body.Close()
}()
// 错误响应处理
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
// 非转移错误,直接透传给客户端
c.Status(resp.StatusCode)
for key, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
if _, err := c.Writer.Write(respBody); err != nil {
return nil, fmt.Errorf("write upstream error response: %w", err)
}
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 成功响应 — 直接透传
c.Status(resp.StatusCode)
for key, values := range resp.Header {
lower := strings.ToLower(key)
// 透传内容相关头部
if lower == "content-type" || lower == "transfer-encoding" ||
lower == "cache-control" || lower == "x-request-id" {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
}
// 流式复制响应体
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, err := c.Writer.Write(buf[:n]); err != nil {
return nil, fmt.Errorf("stream upstream response write: %w", err)
}
flusher.Flush()
}
if readErr != nil {
break
}
}
} else {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return nil, fmt.Errorf("copy upstream response: %w", err)
}
}
duration := time.Since(startTime)
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Model: "", // 由调用方填充
Stream: clientStream,
Duration: duration,
}, nil
}
...@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac ...@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键 // Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account)) keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey) keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformOpenAI, PlatformSora: case PlatformOpenAI:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account)) keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic: case PlatformAnthropic:
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account)) keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
......
...@@ -60,7 +60,6 @@ func NewTokenRefreshService( ...@@ -60,7 +60,6 @@ func NewTokenRefreshService(
} }
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
claudeRefresher := NewClaudeTokenRefresher(oauthService) claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService) geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
...@@ -85,18 +84,6 @@ func NewTokenRefreshService( ...@@ -85,18 +84,6 @@ func NewTokenRefreshService(
return s return s
} }
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
// 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraAccountRepo(repo)
}
}
}
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖 // SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) { func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) {
s.privacyClientFactory = factory s.privacyClientFactory = factory
......
...@@ -2,7 +2,6 @@ package service ...@@ -2,7 +2,6 @@ package service
import ( import (
"context" "context"
"log"
"time" "time"
) )
...@@ -73,8 +72,6 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m ...@@ -73,8 +72,6 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
type OpenAITokenRefresher struct { type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
} }
// NewOpenAITokenRefresher 创建 OpenAI token刷新器 // NewOpenAITokenRefresher 创建 OpenAI token刷新器
...@@ -90,20 +87,7 @@ func (r *OpenAITokenRefresher) CacheKey(account *Account) string { ...@@ -90,20 +87,7 @@ func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
return OpenAITokenCacheKey(account) return OpenAITokenCacheKey(account)
} }
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 Token 刷新时同步更新 sora_accounts 表
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
r.syncLinkedSora = enabled
}
// CanRefresh 检查是否能处理此账号 // CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
} }
...@@ -121,7 +105,6 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time ...@@ -121,7 +105,6 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time
// Refresh 执行token刷新 // Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段 // 保留原有credentials中的所有字段,只更新token相关字段
// 刷新成功后,异步同步关联的 Sora 账号
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account) tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
...@@ -132,68 +115,5 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m ...@@ -132,68 +115,5 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
newCredentials = MergeCredentials(account.Credentials, newCredentials) newCredentials = MergeCredentials(account.Credentials, newCredentials)
// 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
return newCredentials, nil return newCredentials, nil
} }
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
// 该方法异步执行,失败只记录日志,不影响主流程
//
// 同步策略:
// 1. 更新 accounts.credentials(主表)
// 2. 更新 sora_accounts 扩展表(如果 soraAccountRepo 已设置)
//
// 超时控制:30 秒,防止数据库阻塞导致 goroutine 泄漏
func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
// 添加超时控制,防止 goroutine 泄漏
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// 1. 查找所有关联的 Sora 账号(限定 platform='sora')
soraAccounts, err := r.accountRepo.FindByExtraField(ctx, "linked_openai_account_id", openaiAccountID)
if err != nil {
log.Printf("[TokenSync] 查找关联 Sora 账号失败: openai_account_id=%d err=%v", openaiAccountID, err)
return
}
if len(soraAccounts) == 0 {
// 没有关联的 Sora 账号,直接返回
return
}
// 2. 同步更新每个 Sora 账号的双表数据
for _, soraAccount := range soraAccounts {
// 2.1 更新 accounts.credentials(主表)
soraAccount.Credentials["access_token"] = newCredentials["access_token"]
soraAccount.Credentials["refresh_token"] = newCredentials["refresh_token"]
if expiresAt, ok := newCredentials["expires_at"]; ok {
soraAccount.Credentials["expires_at"] = expiresAt
}
if err := r.accountRepo.Update(ctx, &soraAccount); err != nil {
log.Printf("[TokenSync] 更新 Sora accounts 表失败: sora_account_id=%d openai_account_id=%d err=%v",
soraAccount.ID, openaiAccountID, err)
continue
}
// 2.2 更新 sora_accounts 扩展表(如果仓储已设置)
if r.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": newCredentials["access_token"],
"refresh_token": newCredentials["refresh_token"],
}
if err := r.soraAccountRepo.Upsert(ctx, soraAccount.ID, soraUpdates); err != nil {
log.Printf("[TokenSync] 更新 sora_accounts 表失败: account_id=%d openai_account_id=%d err=%v",
soraAccount.ID, openaiAccountID, err)
// 继续处理其他账号,不中断
}
}
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
}
}
...@@ -242,12 +242,6 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { ...@@ -242,12 +242,6 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
accType: AccountTypeOAuth, accType: AccountTypeOAuth,
want: true, want: true,
}, },
{
name: "sora oauth - cannot refresh directly",
platform: PlatformSora,
accType: AccountTypeOAuth,
want: false,
},
{ {
name: "openai apikey - cannot refresh", name: "openai apikey - cannot refresh",
platform: PlatformOpenAI, platform: PlatformOpenAI,
......
...@@ -110,7 +110,7 @@ type UsageLog struct { ...@@ -110,7 +110,7 @@ type UsageLog struct {
ModelMappingChain *string ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式) // BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string BillingTier *string
// BillingMode 计费模式:token/image(sora 路径为 nil) // BillingMode 计费模式:token/image
BillingMode *string BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
......
...@@ -25,10 +25,6 @@ type User struct { ...@@ -25,10 +25,6 @@ type User struct {
// map[groupID]rateMultiplier // map[groupID]rateMultiplier
GroupRates map[int64]float64 GroupRates map[int64]float64
// Sora 存储配额
SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值)
SoraStorageUsedBytes int64 // Sora 存储已用量
// TOTP 双因素认证字段 // TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP TotpEnabled bool // 是否启用 TOTP
......
...@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { ...@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
// ProvideTokenRefreshService creates and starts TokenRefreshService // ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService( func ProvideTokenRefreshService(
accountRepo AccountRepository, accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
...@@ -54,8 +53,6 @@ func ProvideTokenRefreshService( ...@@ -54,8 +53,6 @@ func ProvideTokenRefreshService(
refreshAPI *OAuthRefreshAPI, refreshAPI *OAuthRefreshAPI,
) *TokenRefreshService { ) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo)
// 注入 OpenAI privacy opt-out 依赖 // 注入 OpenAI privacy opt-out 依赖
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo) svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件) // 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
...@@ -281,30 +278,6 @@ func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink { ...@@ -281,30 +278,6 @@ func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
return sink return sink
} }
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
func ProvideSoraSDKClient(
cfg *config.Config,
httpUpstream HTTPUpstream,
tokenProvider *OpenAITokenProvider,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
) *SoraSDKClient {
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
client.SetAccountRepositories(accountRepo, soraAccountRepo)
return client
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
return svc
}
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig { func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
idempotencyCfg := DefaultIdempotencyConfig() idempotencyCfg := DefaultIdempotencyConfig()
if cfg != nil { if cfg != nil {
...@@ -425,11 +398,6 @@ var ProviderSet = wire.NewSet( ...@@ -425,11 +398,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementService, NewAnnouncementService,
NewAdminService, NewAdminService,
NewGatewayService, NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
ProvideSoraSDKClient,
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
NewSoraGatewayService,
NewOpenAIGatewayService, NewOpenAIGatewayService,
NewOAuthService, NewOAuthService,
NewOpenAIOAuthService, NewOpenAIOAuthService,
......
package soraerror package httputil
import ( import (
"encoding/json" "encoding/json"
......
package soraerror
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsCloudflareChallengeResponse(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-mitigated", "challenge")
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
func TestExtractCloudflareRayID(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
require.Equal(t, "cf_shield_429", code)
require.Equal(t, "rate limited", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
require.Equal(t, "unsupported_country_code", code)
require.Equal(t, "not available", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
require.Equal(t, "", code)
require.Equal(t, "plain text", msg)
}
func TestFormatCloudflareChallengeMessage(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}
...@@ -256,7 +256,6 @@ func shouldBypassEmbeddedFrontend(path string) bool { ...@@ -256,7 +256,6 @@ func shouldBypassEmbeddedFrontend(path string) bool {
return strings.HasPrefix(trimmed, "/api/") || return strings.HasPrefix(trimmed, "/api/") ||
strings.HasPrefix(trimmed, "/v1/") || strings.HasPrefix(trimmed, "/v1/") ||
strings.HasPrefix(trimmed, "/v1beta/") || strings.HasPrefix(trimmed, "/v1beta/") ||
strings.HasPrefix(trimmed, "/sora/") ||
strings.HasPrefix(trimmed, "/antigravity/") || strings.HasPrefix(trimmed, "/antigravity/") ||
strings.HasPrefix(trimmed, "/setup/") || strings.HasPrefix(trimmed, "/setup/") ||
trimmed == "/health" || trimmed == "/health" ||
......
...@@ -434,7 +434,6 @@ func TestFrontendServer_Middleware(t *testing.T) { ...@@ -434,7 +434,6 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users", "/api/v1/users",
"/v1/models", "/v1/models",
"/v1beta/chat", "/v1beta/chat",
"/sora/v1/models",
"/antigravity/test", "/antigravity/test",
"/setup/init", "/setup/init",
"/health", "/health",
...@@ -637,7 +636,6 @@ func TestServeEmbeddedFrontend(t *testing.T) { ...@@ -637,7 +636,6 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users", "/api/users",
"/v1/models", "/v1/models",
"/v1beta/chat", "/v1beta/chat",
"/sora/v1/models",
"/antigravity/test", "/antigravity/test",
"/setup/init", "/setup/init",
"/health", "/health",
......
-- Migration: 090_drop_sora
-- Remove all Sora-related database objects.
-- Drops tables: sora_tasks, sora_generations, sora_accounts
-- Drops columns from: groups, users, usage_logs
-- ============================================================
-- 1. Drop Sora tables
-- ============================================================
DROP TABLE IF EXISTS sora_tasks;
DROP TABLE IF EXISTS sora_generations;
DROP TABLE IF EXISTS sora_accounts;
-- ============================================================
-- 2. Drop Sora columns from groups table
-- ============================================================
ALTER TABLE groups
DROP COLUMN IF EXISTS sora_image_price_360,
DROP COLUMN IF EXISTS sora_image_price_540,
DROP COLUMN IF EXISTS sora_video_price_per_request,
DROP COLUMN IF EXISTS sora_video_price_per_request_hd,
DROP COLUMN IF EXISTS sora_storage_quota_bytes;
-- ============================================================
-- 3. Drop Sora columns from users table
-- ============================================================
ALTER TABLE users
DROP COLUMN IF EXISTS sora_storage_quota_bytes,
DROP COLUMN IF EXISTS sora_storage_used_bytes;
-- ============================================================
-- 4. Drop Sora column from usage_logs table
-- ============================================================
ALTER TABLE usage_logs
DROP COLUMN IF EXISTS media_type;
import { describe, expect, it } from 'vitest'
import {
normalizeGenerationListResponse,
normalizeModelFamiliesResponse
} from '../sora'
describe('sora api normalizers', () => {
it('normalizes generation list from data shape', () => {
const result = normalizeGenerationListResponse({
data: [{ id: 1, status: 'pending' }],
total: 9,
page: 2
})
expect(result.data).toHaveLength(1)
expect(result.total).toBe(9)
expect(result.page).toBe(2)
})
it('normalizes generation list from items shape', () => {
const result = normalizeGenerationListResponse({
items: [{ id: 1, status: 'completed' }],
total: 1
})
expect(result.data).toHaveLength(1)
expect(result.total).toBe(1)
expect(result.page).toBe(1)
})
it('falls back to empty generation list on invalid payload', () => {
const result = normalizeGenerationListResponse(null)
expect(result).toEqual({ data: [], total: 0, page: 1 })
})
it('normalizes family model payload', () => {
const result = normalizeModelFamiliesResponse({
data: [
{
id: 'sora2',
name: 'Sora 2',
type: 'video',
orientations: ['landscape', 'portrait'],
durations: [10, 15]
}
]
})
expect(result).toHaveLength(1)
expect(result[0].id).toBe('sora2')
expect(result[0].orientations).toEqual(['landscape', 'portrait'])
expect(result[0].durations).toEqual([10, 15])
})
it('normalizes legacy flat model list into families', () => {
const result = normalizeModelFamiliesResponse({
items: [
{ id: 'sora2-landscape-10s', type: 'video' },
{ id: 'sora2-portrait-15s', type: 'video' },
{ id: 'gpt-image-square', type: 'image' }
]
})
const sora2 = result.find((m) => m.id === 'sora2')
expect(sora2).toBeTruthy()
expect(sora2?.orientations).toEqual(['landscape', 'portrait'])
expect(sora2?.durations).toEqual([10, 15])
const image = result.find((m) => m.id === 'gpt-image')
expect(image).toBeTruthy()
expect(image?.type).toBe('image')
expect(image?.orientations).toEqual(['square'])
})
it('falls back to empty families on invalid payload', () => {
expect(normalizeModelFamiliesResponse(undefined)).toEqual([])
expect(normalizeModelFamiliesResponse({})).toEqual([])
})
})
...@@ -568,28 +568,6 @@ export async function refreshOpenAIToken( ...@@ -568,28 +568,6 @@ export async function refreshOpenAIToken(
return data return data
} }
/**
* Validate Sora session token and exchange to access token
* @param sessionToken - Sora session token
* @param proxyId - Optional proxy ID
* @param endpoint - API endpoint path
* @returns Token information including access_token
*/
export async function validateSoraSessionToken(
sessionToken: string,
proxyId?: number | null,
endpoint: string = '/admin/sora/st2at'
): Promise<Record<string, unknown>> {
const payload: { session_token: string; proxy_id?: number } = {
session_token: sessionToken
}
if (proxyId) {
payload.proxy_id = proxyId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data
}
/** /**
* Batch operation result type * Batch operation result type
*/ */
...@@ -663,7 +641,6 @@ export const accountsAPI = { ...@@ -663,7 +641,6 @@ export const accountsAPI = {
generateAuthUrl, generateAuthUrl,
exchangeCode, exchangeCode,
refreshOpenAIToken, refreshOpenAIToken,
validateSoraSessionToken,
batchCreate, batchCreate,
batchUpdateCredentials, batchUpdateCredentials,
bulkUpdate, bulkUpdate,
......
...@@ -40,7 +40,6 @@ export interface SystemSettings { ...@@ -40,7 +40,6 @@ export interface SystemSettings {
hide_ccs_import_button: boolean hide_ccs_import_button: boolean
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
sora_client_enabled: boolean
backend_mode_enabled: boolean backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[] custom_endpoints: CustomEndpoint[]
...@@ -114,7 +113,6 @@ export interface UpdateSettingsRequest { ...@@ -114,7 +113,6 @@ export interface UpdateSettingsRequest {
hide_ccs_import_button?: boolean hide_ccs_import_button?: boolean
purchase_subscription_enabled?: boolean purchase_subscription_enabled?: boolean
purchase_subscription_url?: string purchase_subscription_url?: string
sora_client_enabled?: boolean
backend_mode_enabled?: boolean backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[] custom_endpoints?: CustomEndpoint[]
...@@ -394,142 +392,6 @@ export async function updateBetaPolicySettings( ...@@ -394,142 +392,6 @@ export async function updateBetaPolicySettings(
return data return data
} }
// ==================== Sora S3 Settings ====================
export interface SoraS3Settings {
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key_configured: boolean
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes: number
}
export interface SoraS3Profile {
profile_id: string
name: string
is_active: boolean
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key_configured: boolean
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes: number
updated_at: string
}
export interface ListSoraS3ProfilesResponse {
active_profile_id: string
items: SoraS3Profile[]
}
export interface UpdateSoraS3SettingsRequest {
profile_id?: string
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key?: string
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes: number
}
export interface CreateSoraS3ProfileRequest {
profile_id: string
name: string
set_active?: boolean
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key?: string
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes: number
}
export interface UpdateSoraS3ProfileRequest {
name: string
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key?: string
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes: number
}
export interface TestSoraS3ConnectionRequest {
profile_id?: string
enabled: boolean
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key?: string
prefix: string
force_path_style: boolean
cdn_url: string
default_storage_quota_bytes?: number
}
export async function getSoraS3Settings(): Promise<SoraS3Settings> {
const { data } = await apiClient.get<SoraS3Settings>('/admin/settings/sora-s3')
return data
}
export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise<SoraS3Settings> {
const { data } = await apiClient.put<SoraS3Settings>('/admin/settings/sora-s3', settings)
return data
}
export async function testSoraS3Connection(
settings: TestSoraS3ConnectionRequest
): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings)
return data
}
export async function listSoraS3Profiles(): Promise<ListSoraS3ProfilesResponse> {
const { data } = await apiClient.get<ListSoraS3ProfilesResponse>('/admin/settings/sora-s3/profiles')
return data
}
export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise<SoraS3Profile> {
const { data } = await apiClient.post<SoraS3Profile>('/admin/settings/sora-s3/profiles', request)
return data
}
export async function updateSoraS3Profile(profileID: string, request: UpdateSoraS3ProfileRequest): Promise<SoraS3Profile> {
const { data } = await apiClient.put<SoraS3Profile>(`/admin/settings/sora-s3/profiles/${profileID}`, request)
return data
}
export async function deleteSoraS3Profile(profileID: string): Promise<void> {
await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`)
}
export async function setActiveSoraS3Profile(profileID: string): Promise<SoraS3Profile> {
const { data } = await apiClient.post<SoraS3Profile>(`/admin/settings/sora-s3/profiles/${profileID}/activate`)
return data
}
export const settingsAPI = { export const settingsAPI = {
getSettings, getSettings,
updateSettings, updateSettings,
...@@ -545,15 +407,7 @@ export const settingsAPI = { ...@@ -545,15 +407,7 @@ export const settingsAPI = {
getRectifierSettings, getRectifierSettings,
updateRectifierSettings, updateRectifierSettings,
getBetaPolicySettings, getBetaPolicySettings,
updateBetaPolicySettings, updateBetaPolicySettings
getSoraS3Settings,
updateSoraS3Settings,
testSoraS3Connection,
listSoraS3Profiles,
createSoraS3Profile,
updateSoraS3Profile,
deleteSoraS3Profile,
setActiveSoraS3Profile
} }
export default settingsAPI export default settingsAPI
/**
* Sora 客户端 API
* 封装所有 Sora 生成、作品库、配额等接口调用
*/
import { apiClient } from './client'
// ==================== 类型定义 ====================
export interface SoraGeneration {
id: number
user_id: number
model: string
prompt: string
media_type: string
status: string // pending | generating | completed | failed | cancelled
storage_type: string // upstream | s3 | local
media_url: string
media_urls: string[]
s3_object_keys: string[]
file_size_bytes: number
error_message: string
created_at: string
completed_at?: string
}
export interface GenerateRequest {
model: string
prompt: string
video_count?: number
media_type?: string
image_input?: string
api_key_id?: number
}
export interface GenerateResponse {
generation_id: number
status: string
}
export interface GenerationListResponse {
data: SoraGeneration[]
total: number
page: number
}
export interface QuotaInfo {
quota_bytes: number
used_bytes: number
available_bytes: number
quota_source: string // user | group | system | unlimited
source?: string // 兼容旧字段
}
export interface StorageStatus {
s3_enabled: boolean
s3_healthy: boolean
local_enabled: boolean
}
/** 单个扁平模型(旧接口,保留兼容) */
export interface SoraModel {
id: string
name: string
type: string // video | image
orientation?: string
duration?: number
}
/** 模型家族(新接口 — 后端从 soraModelConfigs 自动聚合) */
export interface SoraModelFamily {
id: string // 家族 ID,如 "sora2"
name: string // 显示名,如 "Sora 2"
type: string // "video" | "image"
orientations: string[] // ["landscape", "portrait"] 或 ["landscape", "portrait", "square"]
durations?: number[] // [10, 15, 25](仅视频模型)
}
type LooseRecord = Record<string, unknown>
function asRecord(value: unknown): LooseRecord | null {
return value !== null && typeof value === 'object' ? value as LooseRecord : null
}
function asArray<T = unknown>(value: unknown): T[] {
return Array.isArray(value) ? value as T[] : []
}
function asPositiveInt(value: unknown): number | null {
const n = Number(value)
if (!Number.isFinite(n) || n <= 0) return null
return Math.round(n)
}
function dedupeStrings(values: string[]): string[] {
return Array.from(new Set(values))
}
function extractOrientationFromModelID(modelID: string): string | null {
const m = modelID.match(/-(landscape|portrait|square)(?:-\d+s)?$/i)
return m ? m[1].toLowerCase() : null
}
function extractDurationFromModelID(modelID: string): number | null {
const m = modelID.match(/-(\d+)s$/i)
return m ? asPositiveInt(m[1]) : null
}
function normalizeLegacyFamilies(candidates: unknown[]): SoraModelFamily[] {
const familyMap = new Map<string, SoraModelFamily>()
for (const item of candidates) {
const model = asRecord(item)
if (!model || typeof model.id !== 'string' || model.id.trim() === '') continue
const rawID = model.id.trim()
const type = model.type === 'image' ? 'image' : 'video'
const name = typeof model.name === 'string' && model.name.trim() ? model.name.trim() : rawID
const baseID = rawID.replace(/-(landscape|portrait|square)(?:-\d+s)?$/i, '')
const orientation =
typeof model.orientation === 'string' && model.orientation
? model.orientation.toLowerCase()
: extractOrientationFromModelID(rawID)
const duration = asPositiveInt(model.duration) ?? extractDurationFromModelID(rawID)
const familyKey = baseID || rawID
const family = familyMap.get(familyKey) ?? {
id: familyKey,
name,
type,
orientations: [],
durations: []
}
if (orientation) {
family.orientations.push(orientation)
}
if (type === 'video' && duration) {
family.durations = family.durations || []
family.durations.push(duration)
}
familyMap.set(familyKey, family)
}
return Array.from(familyMap.values())
.map((family) => ({
...family,
orientations:
family.orientations.length > 0
? dedupeStrings(family.orientations)
: (family.type === 'image' ? ['square'] : ['landscape']),
durations:
family.type === 'video'
? Array.from(new Set((family.durations || []).filter((d): d is number => Number.isFinite(d)))).sort((a, b) => a - b)
: []
}))
.filter((family) => family.id !== '')
}
function normalizeModelFamilyRecord(item: unknown): SoraModelFamily | null {
const model = asRecord(item)
if (!model || typeof model.id !== 'string' || model.id.trim() === '') return null
// 仅把明确的“家族结构”识别为 family;老结构(单模型)走 legacy 聚合逻辑。
if (!Array.isArray(model.orientations) && !Array.isArray(model.durations)) return null
const orientations = asArray<string>(model.orientations).filter((o): o is string => typeof o === 'string' && o.length > 0)
const durations = asArray<unknown>(model.durations)
.map(asPositiveInt)
.filter((d): d is number => d !== null)
return {
id: model.id.trim(),
name: typeof model.name === 'string' && model.name.trim() ? model.name.trim() : model.id.trim(),
type: model.type === 'image' ? 'image' : 'video',
orientations: dedupeStrings(orientations),
durations: Array.from(new Set(durations)).sort((a, b) => a - b)
}
}
function extractCandidateArray(payload: unknown): unknown[] {
if (Array.isArray(payload)) return payload
const record = asRecord(payload)
if (!record) return []
const keys: Array<keyof LooseRecord> = ['data', 'items', 'models', 'families']
for (const key of keys) {
if (Array.isArray(record[key])) {
return record[key] as unknown[]
}
}
return []
}
export function normalizeModelFamiliesResponse(payload: unknown): SoraModelFamily[] {
const candidates = extractCandidateArray(payload)
if (candidates.length === 0) return []
const normalized = candidates
.map(normalizeModelFamilyRecord)
.filter((item): item is SoraModelFamily => item !== null)
if (normalized.length > 0) return normalized
return normalizeLegacyFamilies(candidates)
}
export function normalizeGenerationListResponse(payload: unknown): GenerationListResponse {
const record = asRecord(payload)
if (!record) {
return { data: [], total: 0, page: 1 }
}
const data = Array.isArray(record.data)
? (record.data as SoraGeneration[])
: Array.isArray(record.items)
? (record.items as SoraGeneration[])
: []
const total = Number(record.total)
const page = Number(record.page)
return {
data,
total: Number.isFinite(total) ? total : data.length,
page: Number.isFinite(page) && page > 0 ? page : 1
}
}
// ==================== API 方法 ====================
/** 异步生成 — 创建 pending 记录后立即返回 */
export async function generate(req: GenerateRequest): Promise<GenerateResponse> {
const { data } = await apiClient.post<GenerateResponse>('/sora/generate', req)
return data
}
/** 查询生成记录列表 */
export async function listGenerations(params?: {
page?: number
page_size?: number
status?: string
storage_type?: string
media_type?: string
}): Promise<GenerationListResponse> {
const { data } = await apiClient.get<unknown>('/sora/generations', { params })
return normalizeGenerationListResponse(data)
}
/** 查询生成记录详情 */
export async function getGeneration(id: number): Promise<SoraGeneration> {
const { data } = await apiClient.get<SoraGeneration>(`/sora/generations/${id}`)
return data
}
/** 删除生成记录 */
export async function deleteGeneration(id: number): Promise<{ message: string }> {
const { data } = await apiClient.delete<{ message: string }>(`/sora/generations/${id}`)
return data
}
/** 取消生成任务 */
export async function cancelGeneration(id: number): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(`/sora/generations/${id}/cancel`)
return data
}
/** 手动保存到 S3 */
export async function saveToStorage(
id: number
): Promise<{ message: string; object_key: string; object_keys?: string[] }> {
const { data } = await apiClient.post<{ message: string; object_key: string; object_keys?: string[] }>(
`/sora/generations/${id}/save`
)
return data
}
/** 查询配额信息 */
export async function getQuota(): Promise<QuotaInfo> {
const { data } = await apiClient.get<QuotaInfo>('/sora/quota')
return data
}
/** 获取可用模型家族列表 */
export async function getModels(): Promise<SoraModelFamily[]> {
const { data } = await apiClient.get<unknown>('/sora/models')
return normalizeModelFamiliesResponse(data)
}
/** 获取存储状态 */
export async function getStorageStatus(): Promise<StorageStatus> {
const { data } = await apiClient.get<StorageStatus>('/sora/storage-status')
return data
}
const soraAPI = {
generate,
listGenerations,
getGeneration,
deleteGeneration,
cancelGeneration,
saveToStorage,
getQuota,
getModels,
getStorageStatus
}
export default soraAPI
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment