Commit eb2dce92 authored by 陈曦's avatar 陈曦
Browse files

升级v1.0.8 解决冲突

parents 7b83d6e7 339d906e
......@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
......@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
......
......@@ -22,8 +22,6 @@ import (
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
"default subscription group must exist and be subscription type",
......@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
}
......@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
......@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
......@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
s.onS3Update = callback
}
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
......@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
......@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
......@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
......@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
......@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
......@@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
type soraS3ProfilesStore struct {
ActiveProfileID string `json:"active_profile_id"`
Items []soraS3ProfileStoreItem `json:"items"`
}
type soraS3ProfileStoreItem struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
UpdatedAt string `json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
profiles, err := s.ListSoraS3Profiles(ctx)
if err != nil {
return nil, err
}
activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
if activeProfile == nil {
return &SoraS3Settings{}, nil
}
return &SoraS3Settings{
Enabled: activeProfile.Enabled,
Endpoint: activeProfile.Endpoint,
Region: activeProfile.Region,
Bucket: activeProfile.Bucket,
AccessKeyID: activeProfile.AccessKeyID,
SecretAccessKey: activeProfile.SecretAccessKey,
SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
Prefix: activeProfile.Prefix,
ForcePathStyle: activeProfile.ForcePathStyle,
CDNURL: activeProfile.CDNURL,
DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
}, nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return err
}
now := time.Now().UTC().Format(time.RFC3339)
activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
if activeIndex < 0 {
activeID := "default"
if hasSoraS3ProfileID(store.Items, activeID) {
activeID = fmt.Sprintf("default-%d", time.Now().Unix())
}
store.Items = append(store.Items, soraS3ProfileStoreItem{
ProfileID: activeID,
Name: "Default",
UpdatedAt: now,
})
store.ActiveProfileID = activeID
activeIndex = len(store.Items) - 1
}
active := store.Items[activeIndex]
active.Enabled = settings.Enabled
active.Endpoint = strings.TrimSpace(settings.Endpoint)
active.Region = strings.TrimSpace(settings.Region)
active.Bucket = strings.TrimSpace(settings.Bucket)
active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
active.Prefix = strings.TrimSpace(settings.Prefix)
active.ForcePathStyle = settings.ForcePathStyle
active.CDNURL = strings.TrimSpace(settings.CDNURL)
active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
if settings.SecretAccessKey != "" {
active.SecretAccessKey = settings.SecretAccessKey
}
active.UpdatedAt = now
store.Items[activeIndex] = active
return s.persistSoraS3ProfilesStore(ctx, store)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
return convertSoraS3ProfilesStore(store), nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
if profile == nil {
return nil, fmt.Errorf("profile cannot be nil")
}
profileID := strings.TrimSpace(profile.ProfileID)
if profileID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
name := strings.TrimSpace(profile.Name)
if name == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
if hasSoraS3ProfileID(store.Items, profileID) {
return nil, ErrSoraS3ProfileExists
}
now := time.Now().UTC().Format(time.RFC3339)
store.Items = append(store.Items, soraS3ProfileStoreItem{
ProfileID: profileID,
Name: name,
Enabled: profile.Enabled,
Endpoint: strings.TrimSpace(profile.Endpoint),
Region: strings.TrimSpace(profile.Region),
Bucket: strings.TrimSpace(profile.Bucket),
AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
SecretAccessKey: profile.SecretAccessKey,
Prefix: strings.TrimSpace(profile.Prefix),
ForcePathStyle: profile.ForcePathStyle,
CDNURL: strings.TrimSpace(profile.CDNURL),
DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
})
if setActive || store.ActiveProfileID == "" {
store.ActiveProfileID = profileID
}
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
created := findSoraS3ProfileByID(profiles.Items, profileID)
if created == nil {
return nil, ErrSoraS3ProfileNotFound
}
return created, nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
if profile == nil {
return nil, fmt.Errorf("profile cannot be nil")
}
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return nil, ErrSoraS3ProfileNotFound
}
target := store.Items[targetIndex]
name := strings.TrimSpace(profile.Name)
if name == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
}
target.Name = name
target.Enabled = profile.Enabled
target.Endpoint = strings.TrimSpace(profile.Endpoint)
target.Region = strings.TrimSpace(profile.Region)
target.Bucket = strings.TrimSpace(profile.Bucket)
target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
target.Prefix = strings.TrimSpace(profile.Prefix)
target.ForcePathStyle = profile.ForcePathStyle
target.CDNURL = strings.TrimSpace(profile.CDNURL)
target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
if profile.SecretAccessKey != "" {
target.SecretAccessKey = profile.SecretAccessKey
}
target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
store.Items[targetIndex] = target
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
updated := findSoraS3ProfileByID(profiles.Items, targetID)
if updated == nil {
return nil, ErrSoraS3ProfileNotFound
}
return updated, nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return ErrSoraS3ProfileNotFound
}
store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
if store.ActiveProfileID == targetID {
store.ActiveProfileID = ""
if len(store.Items) > 0 {
store.ActiveProfileID = store.Items[0].ProfileID
}
}
return s.persistSoraS3ProfilesStore(ctx, store)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return nil, ErrSoraS3ProfileNotFound
}
store.ActiveProfileID = targetID
store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
if active == nil {
return nil, ErrSoraS3ProfileNotFound
}
return active, nil
}
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
if err == nil {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return &soraS3ProfilesStore{}, nil
}
var store soraS3ProfilesStore
if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
if legacyErr != nil {
return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
}
if isEmptyLegacySoraS3Settings(legacy) {
return &soraS3ProfilesStore{}, nil
}
now := time.Now().UTC().Format(time.RFC3339)
return &soraS3ProfilesStore{
ActiveProfileID: "default",
Items: []soraS3ProfileStoreItem{
{
ProfileID: "default",
Name: "Default",
Enabled: legacy.Enabled,
Endpoint: strings.TrimSpace(legacy.Endpoint),
Region: strings.TrimSpace(legacy.Region),
Bucket: strings.TrimSpace(legacy.Bucket),
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
SecretAccessKey: legacy.SecretAccessKey,
Prefix: strings.TrimSpace(legacy.Prefix),
ForcePathStyle: legacy.ForcePathStyle,
CDNURL: strings.TrimSpace(legacy.CDNURL),
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
},
},
}, nil
}
normalized := normalizeSoraS3ProfilesStore(store)
return &normalized, nil
}
if !errors.Is(err, ErrSettingNotFound) {
return nil, fmt.Errorf("get sora s3 profiles: %w", err)
}
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
if legacyErr != nil {
return nil, legacyErr
}
if isEmptyLegacySoraS3Settings(legacy) {
return &soraS3ProfilesStore{}, nil
}
now := time.Now().UTC().Format(time.RFC3339)
return &soraS3ProfilesStore{
ActiveProfileID: "default",
Items: []soraS3ProfileStoreItem{
{
ProfileID: "default",
Name: "Default",
Enabled: legacy.Enabled,
Endpoint: strings.TrimSpace(legacy.Endpoint),
Region: strings.TrimSpace(legacy.Region),
Bucket: strings.TrimSpace(legacy.Bucket),
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
SecretAccessKey: legacy.SecretAccessKey,
Prefix: strings.TrimSpace(legacy.Prefix),
ForcePathStyle: legacy.ForcePathStyle,
CDNURL: strings.TrimSpace(legacy.CDNURL),
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
},
},
}, nil
}
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
if store == nil {
return fmt.Errorf("sora s3 profiles store cannot be nil")
}
normalized := normalizeSoraS3ProfilesStore(*store)
data, err := json.Marshal(normalized)
if err != nil {
return fmt.Errorf("marshal sora s3 profiles: %w", err)
}
updates := map[string]string{
SettingKeySoraS3Profiles: string(data),
}
active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
if active == nil {
updates[SettingKeySoraS3Enabled] = "false"
updates[SettingKeySoraS3Endpoint] = ""
updates[SettingKeySoraS3Region] = ""
updates[SettingKeySoraS3Bucket] = ""
updates[SettingKeySoraS3AccessKeyID] = ""
updates[SettingKeySoraS3Prefix] = ""
updates[SettingKeySoraS3ForcePathStyle] = "false"
updates[SettingKeySoraS3CDNURL] = ""
updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
updates[SettingKeySoraS3SecretAccessKey] = ""
} else {
updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
}
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return err
}
if s.onUpdate != nil {
s.onUpdate()
}
if s.onS3Update != nil {
s.onS3Update()
}
return nil
}
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
keys := []string{
SettingKeySoraS3Enabled,
SettingKeySoraS3Endpoint,
SettingKeySoraS3Region,
SettingKeySoraS3Bucket,
SettingKeySoraS3AccessKeyID,
SettingKeySoraS3SecretAccessKey,
SettingKeySoraS3Prefix,
SettingKeySoraS3ForcePathStyle,
SettingKeySoraS3CDNURL,
SettingKeySoraDefaultStorageQuotaBytes,
}
values, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
}
result := &SoraS3Settings{
Enabled: values[SettingKeySoraS3Enabled] == "true",
Endpoint: values[SettingKeySoraS3Endpoint],
Region: values[SettingKeySoraS3Region],
Bucket: values[SettingKeySoraS3Bucket],
AccessKeyID: values[SettingKeySoraS3AccessKeyID],
SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
Prefix: values[SettingKeySoraS3Prefix],
ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
CDNURL: values[SettingKeySoraS3CDNURL],
}
if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
result.DefaultStorageQuotaBytes = v
}
return result, nil
}
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
seen := make(map[string]struct{}, len(store.Items))
normalized := soraS3ProfilesStore{
ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
}
now := time.Now().UTC().Format(time.RFC3339)
for idx := range store.Items {
item := store.Items[idx]
item.ProfileID = strings.TrimSpace(item.ProfileID)
if item.ProfileID == "" {
item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
}
if _, exists := seen[item.ProfileID]; exists {
continue
}
seen[item.ProfileID] = struct{}{}
item.Name = strings.TrimSpace(item.Name)
if item.Name == "" {
item.Name = item.ProfileID
}
item.Endpoint = strings.TrimSpace(item.Endpoint)
item.Region = strings.TrimSpace(item.Region)
item.Bucket = strings.TrimSpace(item.Bucket)
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
item.Prefix = strings.TrimSpace(item.Prefix)
item.CDNURL = strings.TrimSpace(item.CDNURL)
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
if item.UpdatedAt == "" {
item.UpdatedAt = now
}
normalized.Items = append(normalized.Items, item)
}
if len(normalized.Items) == 0 {
normalized.ActiveProfileID = ""
return normalized
}
if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
return normalized
}
normalized.ActiveProfileID = normalized.Items[0].ProfileID
return normalized
}
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
if store == nil {
return &SoraS3ProfileList{}
}
items := make([]SoraS3Profile, 0, len(store.Items))
for idx := range store.Items {
item := store.Items[idx]
items = append(items, SoraS3Profile{
ProfileID: item.ProfileID,
Name: item.Name,
IsActive: item.ProfileID == store.ActiveProfileID,
Enabled: item.Enabled,
Endpoint: item.Endpoint,
Region: item.Region,
Bucket: item.Bucket,
AccessKeyID: item.AccessKeyID,
SecretAccessKey: item.SecretAccessKey,
SecretAccessKeyConfigured: item.SecretAccessKey != "",
Prefix: item.Prefix,
ForcePathStyle: item.ForcePathStyle,
CDNURL: item.CDNURL,
DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
UpdatedAt: item.UpdatedAt,
})
}
return &SoraS3ProfileList{
ActiveProfileID: store.ActiveProfileID,
Items: items,
}
}
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == activeProfileID {
return &items[idx]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == profileID {
return &items[idx]
}
}
return nil
}
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
for idx := range items {
if items[idx].ProfileID == activeProfileID {
return &items[idx]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
for idx := range items {
if items[idx].ProfileID == profileID {
return idx
}
}
return -1
}
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
return findSoraS3ProfileIndex(items, profileID) >= 0
}
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
if settings == nil {
return true
}
if settings.Enabled {
return false
}
if strings.TrimSpace(settings.Endpoint) != "" {
return false
}
if strings.TrimSpace(settings.Region) != "" {
return false
}
if strings.TrimSpace(settings.Bucket) != "" {
return false
}
if strings.TrimSpace(settings.AccessKeyID) != "" {
return false
}
if settings.SecretAccessKey != "" {
return false
}
if strings.TrimSpace(settings.Prefix) != "" {
return false
}
if strings.TrimSpace(settings.CDNURL) != "" {
return false
}
return settings.DefaultStorageQuotaBytes == 0
}
func maxInt64(value int64, min int64) int64 {
if value < min {
return min
}
return value
}
......@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
......@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
......@@ -116,46 +114,6 @@ type PublicSettings struct {
Version string
}
// SoraS3Settings Sora S3 存储配置
type SoraS3Settings struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type SoraS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
UpdatedAt string `json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type SoraS3ProfileList struct {
ActiveProfileID string `json:"active_profile_id"`
Items []SoraS3Profile `json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
......
package service
import "context"
// SoraAccountRepository Sora 账号扩展表仓储接口
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
//
// 设计说明:
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
// - Token 刷新时需要同时更新两个表以保持数据一致性
type SoraAccountRepository interface {
// Upsert 创建或更新 Sora 账号扩展信息
// accountID: 关联的 accounts.id
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
//
// 如果记录不存在则创建,存在则更新。
// 用于:
// 1. 创建 Sora 账号时初始化扩展表
// 2. Token 刷新时同步更新扩展表
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
// 返回 nil, nil 表示记录不存在(非错误)
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
// Delete 删除 Sora 账号扩展信息
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
Delete(ctx context.Context, accountID int64) error
}
// SoraAccount Sora 账号扩展信息
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
type SoraAccount struct {
AccountID int64 // 关联的 accounts.id
AccessToken string // OAuth access_token
RefreshToken string // OAuth refresh_token
SessionToken string // Session token(可选,用于 ST→AT 兜底)
}
package service
import (
"context"
"fmt"
"net/http"
)
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
Enabled() bool
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
DeletePost(ctx context.Context, account *Account, postID string) error
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
// SoraImageRequest 图片生成请求参数
type SoraImageRequest struct {
Prompt string
Width int
Height int
MediaID string
}
// SoraVideoRequest 视频生成请求参数
type SoraVideoRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type SoraStoryboardRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
}
// SoraImageTaskStatus 图片任务状态
type SoraImageTaskStatus struct {
ID string
Status string
ProgressPct float64
URLs []string
ErrorMsg string
}
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
GenerationID string
ErrorMsg string
}
// SoraCameoStatus 角色处理中间态
type SoraCameoStatus struct {
Status string
StatusMessage string
DisplayNameHint string
UsernameHint string
ProfileAssetURL string
InstructionSetHint any
InstructionSet any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type SoraCharacterFinalizeRequest struct {
CameoID string
Username string
DisplayName string
ProfileAssetPointer string
InstructionSet any
}
// SoraUpstreamError 上游错误
type SoraUpstreamError struct {
StatusCode int
Message string
Headers http.Header
Body []byte
}
func (e *SoraUpstreamError) Error() string {
if e == nil {
return "sora upstream error"
}
if e.Message != "" {
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
}
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
}
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"math/rand"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
const soraVideoInputMaxBytes = 200 << 20
const soraVideoInputMaxRedirects = 3
const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
"gpt-image-landscape": "540",
"gpt-image-portrait": "540",
}
var soraBlockedHostnames = map[string]struct{}{
"localhost": {},
"localhost.localdomain": {},
"metadata.google.internal": {},
"metadata.google.internal.": {},
}
var soraBlockedCIDRs = mustParseCIDRs([]string{
"0.0.0.0/8",
"10.0.0.0/8",
"100.64.0.0/10",
"127.0.0.0/8",
"169.254.0.0/16",
"172.16.0.0/12",
"192.168.0.0/16",
"224.0.0.0/4",
"240.0.0.0/4",
"::/128",
"::1/128",
"fc00::/7",
"fe80::/10",
})
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
rateLimitService *RateLimitService
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
type soraWatermarkOptions struct {
Enabled bool
ParseMethod string
ParseURL string
ParseToken string
FallbackOnFailure bool
DeletePost bool
}
type soraCharacterOptions struct {
SetPublic bool
DeleteAfterGenerate bool
}
type soraCharacterFlowResult struct {
CameoID string
CharacterID string
Username string
DisplayName string
}
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
func NewSoraGatewayService(
soraClient SoraClient,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
if s.httpUpstream == nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
}
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
}
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 上游未配置",
},
})
}
return nil, errors.New("sora upstream not configured")
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
if strings.TrimSpace(reqModel) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
return nil, errors.New("model is required")
}
originalModel := reqModel
mappedModel := account.GetMappedModel(reqModel)
var upstreamModel string
if mappedModel != "" && mappedModel != reqModel {
reqModel = mappedModel
upstreamModel = mappedModel
}
modelCfg, ok := GetSoraModelConfig(reqModel)
if !ok {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
prompt = strings.TrimSpace(prompt)
imageInput = strings.TrimSpace(imageInput)
videoInput = strings.TrimSpace(videoInput)
remixTargetID = strings.TrimSpace(remixTargetID)
if videoInput != "" && modelCfg.Type != "video" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
return nil, errors.New("video input only supports video models")
}
if videoInput != "" && imageInput != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
return nil, errors.New("image input and video input cannot be used together")
}
characterOnly := videoInput != "" && prompt == ""
if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
}
if modelCfg.Type == "prompt_enhance" {
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
content := strings.TrimSpace(enhancedPrompt)
if content == "" {
content = prompt
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
}
return &ForwardResult{
RequestID: "",
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
characterOpts := parseSoraCharacterOptions(reqBody)
watermarkOpts := parseSoraWatermarkOptions(reqBody)
var characterResult *soraCharacterFlowResult
if videoInput != "" {
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
if videoErr != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
return nil, videoErr
}
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
if videoErr != nil {
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
}
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
characterID := strings.TrimSpace(characterResult.CharacterID)
defer func() {
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
defer cancelCleanup()
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
}
}()
}
if characterOnly {
content := "角色创建成功"
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
resp := buildSoraNonStreamResponse(content, reqModel)
if characterResult != nil {
resp["character_id"] = characterResult.CharacterID
resp["cameo_id"] = characterResult.CameoID
resp["character_username"] = characterResult.Username
resp["character_display_name"] = characterResult.DisplayName
}
c.JSON(http.StatusOK, resp)
}
return &ForwardResult{
RequestID: "",
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
}
}
var imageData []byte
imageFilename := ""
if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
return nil, err
}
imageData = decoded
imageFilename = filename
}
mediaID := ""
if len(imageData) > 0 {
uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
mediaID = uploadID
}
taskID := ""
var err error
videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
Prompt: prompt,
Width: modelCfg.Width,
Height: modelCfg.Height,
MediaID: mediaID,
})
case "video":
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
Prompt: formatSoraStoryboardPrompt(prompt),
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
})
} else {
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
})
}
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
if clientStream && c != nil {
s.prepareSoraStream(c, taskID)
}
var mediaURLs []string
videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
switch modelCfg.Type {
case "image":
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
if videoStatus != nil {
mediaURLs = videoStatus.URLs
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
}
default:
mediaType = "prompt"
}
watermarkPostID := ""
if modelCfg.Type == "video" && watermarkOpts.Enabled {
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
if watermarkErr != nil {
if !watermarkOpts.FallbackOnFailure {
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
}
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
} else if strings.TrimSpace(watermarkURL) != "" {
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
watermarkPostID = strings.TrimSpace(postID)
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
}
}
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
response := buildSoraNonStreamResponse(content, reqModel)
if len(finalURLs) > 0 {
response["media_url"] = finalURLs[0]
if len(finalURLs) > 1 {
response["media_urls"] = finalURLs
}
}
c.JSON(http.StatusOK, response)
}
return &ForwardResult{
RequestID: taskID,
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: mediaType,
MediaURL: firstMediaURL(finalURLs),
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if s == nil || s.cfg == nil {
return ctx, nil
}
timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds
if stream {
timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds
}
if timeoutSeconds <= 0 {
return ctx, nil
}
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
opts := soraWatermarkOptions{
Enabled: parseBoolWithDefault(body, "watermark_free", false),
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
}
if opts.ParseMethod == "" {
opts.ParseMethod = "third_party"
}
return opts
}
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
return soraCharacterOptions{
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
}
}
func parseSoraVideoCount(body map[string]any) int {
if body == nil {
return 1
}
keys := []string{"video_count", "videos", "n_variants"}
for _, key := range keys {
count := parseIntWithDefault(body, key, 0)
if count > 0 {
return clampInt(count, 1, 3)
}
}
return 1
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case bool:
return typed
case int:
return typed != 0
case int32:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
case string:
typed = strings.ToLower(strings.TrimSpace(typed))
if typed == "true" || typed == "1" || typed == "yes" {
return true
}
if typed == "false" || typed == "0" || typed == "no" {
return false
}
}
return def
}
func parseStringWithDefault(body map[string]any, key, def string) string {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
if str, ok := val.(string); ok {
return str
}
return def
}
func parseIntWithDefault(body map[string]any, key string, def int) int {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float64:
return int(typed)
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return parsed
}
}
return def
}
func clampInt(v, minVal, maxVal int) int {
if v < minVal {
return minVal
}
if v > maxVal {
return maxVal
}
return v
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
}
raw, ok := body["cameo_ids"]
if !ok {
return nil
}
switch typed := raw.(type) {
case []string:
out := make([]string, 0, len(typed))
for _, item := range typed {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(typed))
for _, item := range typed {
str, ok := item.(string)
if !ok {
continue
}
str = strings.TrimSpace(str)
if str != "" {
out = append(out, str)
}
}
return out
default:
return nil
}
}
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
if err != nil {
return nil, err
}
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
if err != nil {
return nil, err
}
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
if displayName == "" {
displayName = "Character"
}
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
if profileAssetURL == "" {
return nil, errors.New("profile asset url not found in cameo status")
}
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
if err != nil {
return nil, err
}
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
if err != nil {
return nil, err
}
instructionSet := cameoStatus.InstructionSetHint
if instructionSet == nil {
instructionSet = cameoStatus.InstructionSet
}
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
CameoID: strings.TrimSpace(cameoID),
Username: username,
DisplayName: displayName,
ProfileAssetPointer: assetPointer,
InstructionSet: instructionSet,
})
if err != nil {
return nil, err
}
if opts.SetPublic {
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
return nil, err
}
}
return &soraCharacterFlowResult{
CameoID: strings.TrimSpace(cameoID),
CharacterID: strings.TrimSpace(characterID),
Username: strings.TrimSpace(username),
DisplayName: displayName,
}, nil
}
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
timeout := 10 * time.Minute
interval := 5 * time.Second
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
if maxAttempts < 1 {
maxAttempts = 1
}
var lastErr error
consecutiveErrors := 0
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
if err != nil {
lastErr = err
consecutiveErrors++
if consecutiveErrors >= 3 {
break
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
consecutiveErrors = 0
if status == nil {
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
statusMessage := strings.TrimSpace(status.StatusMessage)
if currentStatus == "failed" {
if statusMessage == "" {
statusMessage = "character creation failed"
}
return nil, errors.New(statusMessage)
}
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
return status, nil
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
}
if lastErr != nil {
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
}
return nil, errors.New("cameo processing timeout")
}
func processSoraCharacterUsername(usernameHint string) string {
usernameHint = strings.TrimSpace(usernameHint)
if usernameHint == "" {
usernameHint = "character"
}
if strings.Contains(usernameHint, ".") {
parts := strings.Split(usernameHint, ".")
usernameHint = strings.TrimSpace(parts[len(parts)-1])
}
if usernameHint == "" {
usernameHint = "character"
}
return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
}
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
generationID = strings.TrimSpace(generationID)
if generationID == "" {
return "", "", errors.New("generation id is required for watermark-free mode")
}
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
if err != nil {
return "", "", err
}
postID = strings.TrimSpace(postID)
if postID == "" {
return "", "", errors.New("watermark-free publish returned empty post id")
}
switch opts.ParseMethod {
case "custom":
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
if parseErr != nil {
return "", postID, parseErr
}
return strings.TrimSpace(urlVal), postID, nil
case "", "third_party":
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
default:
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
}
}
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 404, 429, 529:
return true
default:
return statusCode >= 500
}
}
func buildSoraNonStreamResponse(content, model string) map[string]any {
return map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": content,
},
"finish_reason": "stop",
},
},
}
}
func soraImageSizeFromModel(model string) string {
modelLower := strings.ToLower(model)
if size, ok := soraImageSizeMap[modelLower]; ok {
return size
}
if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") {
return "540"
}
return "360"
}
func soraProErrorMessage(model, upstreamMsg string) string {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
}
if strings.Contains(modelLower, "sora2pro") {
return "当前账号无法使用 Sora Pro 模型,请更换模型或账号"
}
return ""
}
func firstMediaURL(urls []string) string {
if len(urls) == 0 {
return ""
}
return urls[0]
}
func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string {
if path == "" {
return path
}
prefix := "/sora/media"
values := url.Values{}
if rawQuery != "" {
if parsed, err := url.ParseQuery(rawQuery); err == nil {
values = parsed
}
}
signKey := ""
ttlSeconds := 0
if s != nil && s.cfg != nil {
signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey)
ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds
}
values.Del("sig")
values.Del("expires")
signingQuery := values.Encode()
if signKey != "" && ttlSeconds > 0 {
expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix()
signature := SignSoraMediaURL(path, signingQuery, expires, signKey)
if signature != "" {
values.Set("expires", strconv.FormatInt(expires, 10))
values.Set("sig", signature)
prefix = "/sora/media-signed"
}
}
encoded := values.Encode()
if encoded == "" {
return prefix + path
}
return prefix + path + "?" + encoded
}
func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
if c == nil {
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if strings.TrimSpace(requestID) != "" {
c.Header("x-request-id", requestID)
}
}
func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
if c == nil {
return nil, nil
}
writer := c.Writer
flusher, _ := writer.(http.Flusher)
chunk := map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{
"content": content,
},
},
},
}
encoded, _ := jsonMarshalRaw(chunk)
if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
return nil, err
}
if flusher != nil {
flusher.Flush()
}
ms := int(time.Since(startTime).Milliseconds())
finalChunk := map[string]any{
"id": chunk["id"],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{},
"finish_reason": "stop",
},
},
}
finalEncoded, _ := jsonMarshalRaw(finalChunk)
if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
return &ms, err
}
if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
return &ms, err
}
if flusher != nil {
flusher.Flush()
}
return &ms, nil
}
func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
if c == nil {
return
}
if stream {
flusher, _ := c.Writer.(http.Flusher)
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
_, _ = fmt.Fprint(c.Writer, errorEvent)
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
return
}
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
if err == nil {
return nil
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora",
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
accountID,
model,
upstreamErr.StatusCode,
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
strings.TrimSpace(upstreamErr.Message),
truncateForLog(upstreamErr.Body, 1024),
)
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
var responseHeaders http.Header
if upstreamErr.Headers != nil {
responseHeaders = upstreamErr.Headers.Clone()
}
return &UpstreamFailoverError{
StatusCode: upstreamErr.StatusCode,
ResponseBody: upstreamErr.Body,
ResponseHeaders: responseHeaders,
}
}
msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" {
msg = override
}
s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
return err
}
if errors.Is(err, context.DeadlineExceeded) {
s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
return err
}
s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
return err
}
func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetImageTask(ctx, account, taskID)
if err != nil {
return nil, err
}
switch strings.ToLower(status.Status) {
case "succeeded", "completed":
return status.URLs, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
}
return nil, errors.New("sora image generation failed")
}
if stream {
s.maybeSendPing(c, &lastPing)
}
if err := sleepWithContext(ctx, interval); err != nil {
return nil, err
}
}
return nil, errors.New("sora image generation timeout")
}
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
if err != nil {
return nil, err
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
}
return nil, errors.New("sora video generation failed")
}
if stream {
s.maybeSendPing(c, &lastPing)
}
if err := sleepWithContext(ctx, interval); err != nil {
return nil, err
}
}
return nil, errors.New("sora video generation timeout")
}
func (s *SoraGatewayService) pollInterval() time.Duration {
if s == nil || s.cfg == nil {
return 2 * time.Second
}
interval := s.cfg.Sora.Client.PollIntervalSeconds
if interval <= 0 {
interval = 2
}
return time.Duration(interval) * time.Second
}
func (s *SoraGatewayService) pollMaxAttempts() int {
if s == nil || s.cfg == nil {
return 600
}
maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
if maxAttempts <= 0 {
maxAttempts = 600
}
return maxAttempts
}
func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
if c == nil {
return
}
interval := 10 * time.Second
if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
}
if time.Since(*lastPing) < interval {
return
}
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
*lastPing = time.Now()
}
}
func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
if len(urls) == 0 {
return urls
}
output := make([]string, 0, len(urls))
for _, raw := range urls {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
output = append(output, raw)
continue
}
pathVal := raw
if !strings.HasPrefix(pathVal, "/") {
pathVal = "/" + pathVal
}
output = append(output, s.buildSoraMediaURL(pathVal, ""))
}
return output
}
// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
func jsonMarshalRaw(v any) ([]byte, error) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
if err := enc.Encode(v); err != nil {
return nil, err
}
// Encode 会追加换行符,去掉它
b := buf.Bytes()
if len(b) > 0 && b[len(b)-1] == '\n' {
b = b[:len(b)-1]
}
return b, nil
}
func buildSoraContent(mediaType string, urls []string) string {
switch mediaType {
case "image":
parts := make([]string, 0, len(urls))
for _, u := range urls {
parts = append(parts, fmt.Sprintf("![image](%s)", u))
}
return strings.Join(parts, "\n")
case "video":
if len(urls) == 0 {
return ""
}
return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", urls[0])
default:
return ""
}
}
func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
if body == nil {
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
}
if v, ok := body["video"].(string); ok {
videoInput = v
}
if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
prompt = v
}
if messages, ok := body["messages"].([]any); ok {
builder := strings.Builder{}
for _, raw := range messages {
msg, ok := raw.(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if role != "" && role != "user" {
continue
}
content := msg["content"]
text, img, vid := parseSoraMessageContent(content)
if text != "" {
if builder.Len() > 0 {
_, _ = builder.WriteString("\n")
}
_, _ = builder.WriteString(text)
}
if imageInput == "" && img != "" {
imageInput = img
}
if videoInput == "" && vid != "" {
videoInput = vid
}
}
if prompt == "" {
prompt = builder.String()
}
}
if remixTargetID == "" {
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
}
prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
switch val := content.(type) {
case string:
return val, "", ""
case []any:
builder := strings.Builder{}
for _, item := range val {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
t, _ := itemMap["type"].(string)
switch t {
case "text":
if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
if builder.Len() > 0 {
_, _ = builder.WriteString("\n")
}
_, _ = builder.WriteString(txt)
}
case "image_url":
if imageInput == "" {
if urlVal, ok := itemMap["image_url"].(map[string]any); ok {
imageInput = fmt.Sprintf("%v", urlVal["url"])
} else if urlStr, ok := itemMap["image_url"].(string); ok {
imageInput = urlStr
}
}
case "video_url":
if videoInput == "" {
if urlVal, ok := itemMap["video_url"].(map[string]any); ok {
videoInput = fmt.Sprintf("%v", urlVal["url"])
} else if urlStr, ok := itemMap["video_url"].(string); ok {
videoInput = urlStr
}
}
}
}
return builder.String(), imageInput, videoInput
default:
return "", "", ""
}
}
func isSoraStoryboardPrompt(prompt string) bool {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return false
}
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
}
func formatSoraStoryboardPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
if len(matches) == 0 {
return prompt
}
firstBracketPos := strings.Index(prompt, "[")
instructions := ""
if firstBracketPos > 0 {
instructions = strings.TrimSpace(prompt[:firstBracketPos])
}
shots := make([]string, 0, len(matches))
for i, match := range matches {
if len(match) < 3 {
continue
}
duration := strings.TrimSpace(match[1])
scene := strings.TrimSpace(match[2])
if scene == "" {
continue
}
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
}
if len(shots) == 0 {
return prompt
}
timeline := strings.Join(shots, "\n\n")
if instructions == "" {
return timeline
}
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
}
func extractRemixTargetIDFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
}
func cleanRemixLinkFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return prompt
}
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
cleaned = strings.Join(strings.Fields(cleaned), " ")
return strings.TrimSpace(cleaned)
}
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, "", errors.New("empty image input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, "", errors.New("invalid data url")
}
meta := parts[0]
payload := parts[1]
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
ext := ""
if strings.HasPrefix(meta, "data:") {
metaParts := strings.SplitN(meta[5:], ";", 2)
if len(metaParts) > 0 {
if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
ext = exts[0]
}
}
}
filename := "image" + ext
return decoded, filename, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, errors.New("empty video input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, errors.New("invalid video data url")
}
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraVideoInput(ctx, raw)
}
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, "", err
}
client := &http.Client{
Timeout: soraImageInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
if err != nil {
return nil, "", err
}
ext := fileExtFromURL(parsed.String())
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
filename := "image" + ext
return data, filename, nil
}
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, err
}
client := &http.Client{
Timeout: soraVideoInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraVideoInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, errors.New("empty video content")
}
return data, nil
}
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
if maxBytes <= 0 {
return nil, errors.New("invalid max bytes limit")
}
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
limited := io.LimitReader(decoder, maxBytes+1)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
}
return data, nil
}
func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid remote url: %w", err)
}
if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("remote url is not allowed")
}
}
return nil
}
func isSoraBlockedIP(ip net.IP) bool {
if ip == nil {
return true
}
for _, cidr := range soraBlockedCIDRs {
if cidr.Contains(ip) {
return true
}
}
return false
}
func mustParseCIDRs(values []string) []*net.IPNet {
out := make([]*net.IPNet, 0, len(values))
for _, val := range values {
_, cidr, err := net.ParseCIDR(val)
if err != nil {
continue
}
out = append(out, cidr)
}
return out
}
//go:build unit
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var _ SoraClient = (*stubSoraClientForPoll)(nil)
type stubSoraClientForPoll struct {
imageStatus *SoraImageTaskStatus
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
enhanced string
enhanceErr error
storyboard bool
videoReq SoraVideoRequest
parseErr error
postCalls int
deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
return "", nil
}
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
s.videoReq = req
return "task-video", nil
}
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
s.storyboard = true
return "task-video", nil
}
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
return &SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
return nil
}
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
return nil
}
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
s.postCalls++
return "s_post", nil
}
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
s.deleteCalls++
return nil
}
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
if s.parseErr != nil {
return "", s.parseErr
}
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
}
return "enhanced prompt", s.enhanceErr
}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
}
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
s.videoCalls++
return s.videoStatus, nil
}
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
client := &stubSoraClientForPoll{
imageStatus: &SoraImageTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/a.png"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
require.NoError(t, err)
require.Equal(t, []string{"https://example.com/a.png"}, urls)
require.Equal(t, 1, client.imageCalls)
}
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
client := &stubSoraClientForPoll{
enhanced: "cinematic prompt",
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
},
},
}
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model)
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
}
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 3, client.videoReq.VideoCount)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, 0, client.videoCalls)
}
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
parseErr: errors.New("parse failed"),
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 0, client.deleteCalls)
}
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 1, client.deleteCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "reject",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
cfg := &config.Config{
Gateway: config.GatewayConfig{
SoraMediaSigningKey: "test-key",
SoraMediaSignedURLTTLSeconds: 600,
},
}
service := NewSoraGatewayService(nil, nil, nil, cfg)
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
require.Contains(t, url, "/sora/media-signed")
require.Contains(t, url, "expires=")
require.Contains(t, url, "sig=")
}
func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
result := svc.normalizeSoraMediaURLs(nil)
require.Empty(t, result)
result = svc.normalizeSoraMediaURLs([]string{})
require.Empty(t, result)
}
func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
result := svc.normalizeSoraMediaURLs(urls)
require.Equal(t, urls, result)
}
func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
cfg := &config.Config{}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
result := svc.normalizeSoraMediaURLs(urls)
require.Len(t, result, 2)
require.Contains(t, result[0], "/sora/media")
require.Contains(t, result[1], "/sora/media")
}
func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
result := svc.normalizeSoraMediaURLs(urls)
require.Len(t, result, 2)
}
func TestBuildSoraContent_Image(t *testing.T) {
content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
require.Contains(t, content, "![image](https://a.com/1.png)")
require.Contains(t, content, "![image](https://a.com/2.png)")
}
func TestBuildSoraContent_Video(t *testing.T) {
content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
require.Contains(t, content, "<video src='https://a.com/v.mp4'")
}
func TestBuildSoraContent_VideoEmpty(t *testing.T) {
content := buildSoraContent("video", nil)
require.Empty(t, content)
}
func TestBuildSoraContent_Prompt(t *testing.T) {
content := buildSoraContent("prompt", nil)
require.Empty(t, content)
}
func TestSoraImageSizeFromModel(t *testing.T) {
require.Equal(t, "360", soraImageSizeFromModel("gpt-image"))
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-landscape"))
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-portrait"))
require.Equal(t, "540", soraImageSizeFromModel("something-landscape"))
require.Equal(t, "360", soraImageSizeFromModel("unknown-model"))
}
func TestFirstMediaURL(t *testing.T) {
require.Equal(t, "", firstMediaURL(nil))
require.Equal(t, "", firstMediaURL([]string{}))
require.Equal(t, "a", firstMediaURL([]string{"a", "b"}))
}
func TestSoraProErrorMessage(t *testing.T) {
require.Contains(t, soraProErrorMessage("sora2pro-hd", ""), "Pro-HD")
require.Contains(t, soraProErrorMessage("sora2pro", ""), "Pro")
require.Empty(t, soraProErrorMessage("sora-basic", ""))
}
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
body := rec.Body.String()
require.Contains(t, body, "event: error\n")
require.Contains(t, body, "data: [DONE]\n\n")
lines := strings.Split(body, "\n")
require.GreaterOrEqual(t, len(lines), 2)
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "))
data := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
errObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
}
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
sourceHeaders := http.Header{}
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
err := svc.handleSoraRequestError(
context.Background(),
&Account{ID: 1, Platform: PlatformSora},
&SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "forbidden",
Headers: sourceHeaders,
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
},
"sora2-landscape-10s",
nil,
false,
)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.NotNil(t, failoverErr.ResponseHeaders)
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
sourceHeaders.Set("cf-ray", "mutated-after-return")
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
}
func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401))
require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502))
require.False(t, svc.shouldFailoverUpstreamError(200))
require.False(t, svc.shouldFailoverUpstreamError(400))
}
func TestWithSoraTimeout_NilService(t *testing.T) {
var svc *SoraGatewayService
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.Nil(t, cancel)
}
func TestWithSoraTimeout_ZeroTimeout(t *testing.T) {
cfg := &config.Config{}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.Nil(t, cancel)
}
func TestWithSoraTimeout_PositiveTimeout(t *testing.T) {
cfg := &config.Config{
Gateway: config.GatewayConfig{
SoraRequestTimeoutSeconds: 30,
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.NotNil(t, cancel)
cancel()
}
func TestPollInterval(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 5,
},
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
require.Equal(t, 5*time.Second, svc.pollInterval())
// 默认值
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc2.pollInterval() > 0)
}
func TestPollMaxAttempts(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
MaxPollAttempts: 100,
},
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
require.Equal(t, 100, svc.pollMaxAttempts())
// 默认值
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc2.pollMaxAttempts() > 0)
}
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
require.Error(t, err)
}
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
encoded := "data:image/png;base64,aGVsbG8="
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
require.NoError(t, err)
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
require.Error(t, err)
require.Nil(t, data)
}
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
body := map[string]any{
"watermark_free": float64(1),
"watermark_fallback_on_failure": float64(0),
}
opts := parseSoraWatermarkOptions(body)
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
func TestParseSoraVideoCount(t *testing.T) {
require.Equal(t, 1, parseSoraVideoCount(nil))
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
}
//nolint:unused
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
type soraStreamingResult struct {
mediaType string
mediaURLs []string
imageCount int
imageSize string
firstTokenMs *int
}
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if c != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
}
}
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
if s.rateLimitService == nil || account == nil || resp == nil {
return
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
upstreamMsg = msg
}
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if c != nil {
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
c.JSON(resp.StatusCode, responsePayload)
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
if len(respBody) > 0 {
var payload map[string]any
if err := json.Unmarshal(respBody, &payload); err == nil {
if errObj, ok := payload["error"].(map[string]any); ok {
if overrideMessage != "" {
errObj["message"] = overrideMessage
}
payload["error"] = errObj
return payload
}
}
}
return map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": overrideMessage,
},
}
}
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
if resp == nil {
return nil, errors.New("empty response")
}
if clientStream {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
}
w := c.Writer
flusher, _ := w.(http.Flusher)
contentBuilder := strings.Builder{}
var firstTokenMs *int
var upstreamError error
rewriteBuffer := ""
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
sendLine := func(line string) error {
if !clientStream {
return nil
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return err
}
if flusher != nil {
flusher.Flush()
}
return nil
}
for scanner.Scan() {
line := scanner.Text()
if soraSSEDataRe.MatchString(line) {
data := soraSSEDataRe.ReplaceAllString(line, "")
if data == "[DONE]" {
if rewriteBuffer != "" {
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
if err != nil {
return nil, err
}
if flushLine != "" {
if flushContent != "" {
if _, err := contentBuilder.WriteString(flushContent); err != nil {
return nil, err
}
}
if err := sendLine(flushLine); err != nil {
return nil, err
}
}
rewriteBuffer = ""
}
if err := sendLine("data: [DONE]"); err != nil {
return nil, err
}
break
}
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
if errEvent != nil && upstreamError == nil {
upstreamError = errEvent
}
if contentDelta != "" {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
return nil, err
}
}
if err := sendLine(updatedLine); err != nil {
return nil, err
}
continue
}
if err := sendLine(line); err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
content := contentBuilder.String()
mediaType, mediaURLs := s.extractSoraMedia(content)
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
mediaType = "prompt"
}
imageSize := ""
imageCount := 0
if mediaType == "image" {
imageSize = soraImageSizeFromModel(originalModel)
imageCount = len(mediaURLs)
}
if upstreamError != nil && !clientStream {
if c != nil {
c.JSON(http.StatusBadGateway, map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": upstreamError.Error(),
},
})
}
return nil, upstreamError
}
if !clientStream {
response := buildSoraNonStreamResponse(content, originalModel)
if len(mediaURLs) > 0 {
response["media_url"] = mediaURLs[0]
if len(mediaURLs) > 1 {
response["media_urls"] = mediaURLs
}
}
c.JSON(http.StatusOK, response)
}
return &soraStreamingResult{
mediaType: mediaType,
mediaURLs: mediaURLs,
imageCount: imageCount,
imageSize: imageSize,
firstTokenMs: firstTokenMs,
}, nil
}
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
if strings.TrimSpace(data) == "" {
return "data: ", "", nil
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return "data: " + data, "", nil
}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return "data: " + data, "", errors.New(msg)
}
}
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
payload["model"] = originalModel
}
contentDelta, updated := extractSoraContent(payload)
if updated {
var rewritten string
if rewriteBuffer != nil {
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
} else {
rewritten = s.rewriteSoraContent(contentDelta)
}
if rewritten != contentDelta {
applySoraContent(payload, rewritten)
contentDelta = rewritten
}
}
updatedData, err := jsonMarshalRaw(payload)
if err != nil {
return "data: " + data, contentDelta, nil
}
return "data: " + string(updatedData), contentDelta, nil
}
func extractSoraContent(payload map[string]any) (string, bool) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return "", false
}
choice, ok := choices[0].(map[string]any)
if !ok {
return "", false
}
if delta, ok := choice["delta"].(map[string]any); ok {
if content, ok := delta["content"].(string); ok {
return content, true
}
}
if message, ok := choice["message"].(map[string]any); ok {
if content, ok := message["content"].(string); ok {
return content, true
}
}
return "", false
}
func applySoraContent(payload map[string]any, content string) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return
}
choice, ok := choices[0].(map[string]any)
if !ok {
return
}
if delta, ok := choice["delta"].(map[string]any); ok {
delta["content"] = content
choice["delta"] = delta
return
}
if message, ok := choice["message"].(map[string]any); ok {
message["content"] = content
choice["message"] = message
}
}
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
if buffer == nil {
return s.rewriteSoraContent(contentDelta)
}
if contentDelta == "" && *buffer == "" {
return ""
}
combined := *buffer + contentDelta
rewritten := s.rewriteSoraContent(combined)
bufferStart := s.findSoraRewriteBufferStart(rewritten)
if bufferStart < 0 {
*buffer = ""
return rewritten
}
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
bufferStart = len(rewritten) - soraRewriteBufferLimit
}
output := rewritten[:bufferStart]
*buffer = rewritten[bufferStart:]
return output
}
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
minIndex := -1
start := 0
for {
idx := strings.Index(content[start:], "![")
if idx < 0 {
break
}
idx += start
if !hasSoraImageMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + 2
}
lower := strings.ToLower(content)
start = 0
for {
idx := strings.Index(lower[start:], "<video")
if idx < 0 {
break
}
idx += start
if !hasSoraVideoMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + len("<video")
}
return minIndex
}
func hasSoraImageMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func hasSoraVideoMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
if content == "" {
return content
}
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraImageMarkdownRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraVideoHTMLRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
return content
}
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
if buffer == "" {
return "", "", nil
}
rewritten := s.rewriteSoraContent(buffer)
payload := map[string]any{
"choices": []any{
map[string]any{
"delta": map[string]any{
"content": rewritten,
},
"index": 0,
},
},
}
if originalModel != "" {
payload["model"] = originalModel
}
updatedData, err := jsonMarshalRaw(payload)
if err != nil {
return "", "", err
}
return "data: " + string(updatedData), rewritten, nil
}
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
path := parsed.Path
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
return raw
}
return s.buildSoraMediaURL(path, parsed.RawQuery)
}
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
if content == "" {
return "", nil
}
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
return "video", []string{match[1]}
}
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
if len(imageMatches) == 0 {
return "", nil
}
urls := make([]string, 0, len(imageMatches))
for _, match := range imageMatches {
if len(match) > 1 {
urls = append(urls, match[1])
}
}
return "image", urls
}
func isSoraPromptEnhanceModel(model string) bool {
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
}
package service
import (
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type SoraGeneration struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
MediaType string `json:"media_type"` // video / image
Status string `json:"status"` // pending / generating / completed / failed / cancelled
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
FileSizeBytes int64 `json:"file_size_bytes"`
StorageType string `json:"storage_type"` // s3 / local / upstream / none
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
UpstreamTaskID string `json:"upstream_task_id"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const (
SoraGenStatusPending = "pending"
SoraGenStatusGenerating = "generating"
SoraGenStatusCompleted = "completed"
SoraGenStatusFailed = "failed"
SoraGenStatusCancelled = "cancelled"
)
// Sora 存储类型常量
const (
SoraStorageTypeS3 = "s3"
SoraStorageTypeLocal = "local"
SoraStorageTypeUpstream = "upstream"
SoraStorageTypeNone = "none"
)
// SoraGenerationListParams 查询生成记录的参数。
type SoraGenerationListParams struct {
UserID int64
Status string // 可选筛选
StorageType string // 可选筛选
MediaType string // 可选筛选
Page int
PageSize int
}
// SoraGenerationRepository 生成记录持久化接口。
type SoraGenerationRepository interface {
Create(ctx context.Context, gen *SoraGeneration) error
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
Update(ctx context.Context, gen *SoraGeneration) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
}
package service
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive = errors.New("sora generation is not active")
)
const soraGenerationActiveLimit = 3
type soraGenerationRepoAtomicCreator interface {
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
}
type soraGenerationRepoConditionalUpdater interface {
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type SoraGenerationService struct {
genRepo SoraGenerationRepository
s3Storage *SoraS3Storage
quotaService *SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func NewSoraGenerationService(
genRepo SoraGenerationRepository,
s3Storage *SoraS3Storage,
quotaService *SoraQuotaService,
) *SoraGenerationService {
return &SoraGenerationService{
genRepo: genRepo,
s3Storage: s3Storage,
quotaService: quotaService,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
gen := &SoraGeneration{
UserID: userID,
APIKeyID: apiKeyID,
Model: model,
Prompt: prompt,
MediaType: mediaType,
Status: SoraGenStatusPending,
StorageType: SoraStorageTypeNone,
}
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
if err := atomicCreator.CreatePendingWithLimit(
ctx,
gen,
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
soraGenerationActiveLimit,
); err != nil {
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
return nil, err
}
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
if err := s.genRepo.Create(ctx, gen); err != nil {
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
// MarkGenerating 标记为生成中。
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusGenerating
gen.UpstreamTaskID = upstreamTaskID
return s.genRepo.Update(ctx, gen)
}
// MarkCompleted 标记为已完成。
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusCompleted
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkFailed 标记为失败。
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusFailed
gen.ErrorMessage = errMsg
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkCancelled 标记为已取消。
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationNotActive
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationNotActive
}
gen.Status = SoraGenStatusCancelled
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func (s *SoraGenerationService) UpdateStorageForCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusCompleted {
return ErrSoraGenerationStateConflict
}
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
return s.genRepo.Update(ctx, gen)
}
// GetByID 获取记录详情(含权限校验)。
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if gen.UserID != userID {
return nil, fmt.Errorf("无权访问此生成记录")
}
return gen, nil
}
// List 查询生成记录列表(分页 + 筛选)。
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
if params.PageSize > 100 {
params.PageSize = 100
}
return s.genRepo.List(ctx, params)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.UserID != userID {
return fmt.Errorf("无权删除此生成记录")
}
// 清理 S3 文件
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
}
}
// 释放配额(S3/本地均释放)
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
}
}
return s.genRepo.Delete(ctx, id)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
return nil
}
if len(gen.S3ObjectKeys) == 0 {
return nil
}
urls := make([]string, len(gen.S3ObjectKeys))
var wg sync.WaitGroup
var firstErr error
var errMu sync.Mutex
for idx, key := range gen.S3ObjectKeys {
wg.Add(1)
go func(i int, objectKey string) {
defer wg.Done()
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
if err != nil {
errMu.Lock()
if firstErr == nil {
firstErr = err
}
errMu.Unlock()
return
}
urls[i] = url
}(idx, key)
}
wg.Wait()
if firstErr != nil {
return firstErr
}
gen.MediaURL = urls[0]
gen.MediaURLs = urls
return nil
}
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
type stubGenRepo struct {
gens map[int64]*SoraGeneration
nextID int64
createErr error
getErr error
updateErr error
deleteErr error
listErr error
countErr error
countValue int64
}
func newStubGenRepo() *stubGenRepo {
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
}
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
if r.createErr != nil {
return r.createErr
}
gen.ID = r.nextID
gen.CreatedAt = time.Now()
r.nextID++
r.gens[gen.ID] = gen
return nil
}
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
if r.getErr != nil {
return nil, r.getErr
}
if gen, ok := r.gens[id]; ok {
return gen, nil
}
return nil, fmt.Errorf("not found")
}
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
if r.updateErr != nil {
return r.updateErr
}
r.gens[gen.ID] = gen
return nil
}
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
if r.deleteErr != nil {
return r.deleteErr
}
delete(r.gens, id)
return nil
}
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
if r.listErr != nil {
return nil, 0, r.listErr
}
var result []*SoraGeneration
for _, gen := range r.gens {
if gen.UserID != params.UserID {
continue
}
if params.Status != "" && gen.Status != params.Status {
continue
}
if params.StorageType != "" && gen.StorageType != params.StorageType {
continue
}
if params.MediaType != "" && gen.MediaType != params.MediaType {
continue
}
result = append(result, gen)
}
return result, int64(len(result)), nil
}
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
if r.countValue > 0 {
return r.countValue, nil
}
var count int64
statusSet := make(map[string]struct{})
for _, s := range statuses {
statusSet[s] = struct{}{}
}
for _, gen := range r.gens {
if gen.UserID == userID {
if _, ok := statusSet[gen.Status]; ok {
count++
}
}
}
return count, nil
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var _ UserRepository = (*stubUserRepoForQuota)(nil)
type stubUserRepoForQuota struct {
users map[int64]*User
updateErr error
}
func newStubUserRepoForQuota() *stubUserRepoForQuota {
return &stubUserRepoForQuota{users: make(map[int64]*User)}
}
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
if u, ok := r.users[id]; ok {
return u, nil
}
return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
if r.updateErr != nil {
return r.updateErr
}
r.users[user.ID] = user
return nil
}
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
return nil, nil
}
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubUserRepoForQuota) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
storage := &SoraS3Storage{}
storage.cfg = &SoraS3Settings{
Enabled: true,
Bucket: "test-bucket",
CDNURL: cdnURL,
}
// 需要 non-nil client 使 getClient 命中缓存
storage.client = s3.New(s3.Options{})
return storage
}
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
func newS3StorageFailingDelete() *SoraS3Storage {
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
}
// ==================== CreatePending ====================
func TestCreatePending_Success(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
require.NoError(t, err)
require.Equal(t, int64(1), gen.ID)
require.Equal(t, int64(1), gen.UserID)
require.Equal(t, "sora2-landscape-10s", gen.Model)
require.Equal(t, "一只猫跳舞", gen.Prompt)
require.Equal(t, "video", gen.MediaType)
require.Equal(t, SoraGenStatusPending, gen.Status)
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
require.Nil(t, gen.APIKeyID)
}
func TestCreatePending_WithAPIKeyID(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
apiKeyID := int64(42)
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
require.NoError(t, err)
require.NotNil(t, gen.APIKeyID)
require.Equal(t, int64(42), *gen.APIKeyID)
}
func TestCreatePending_RepoError(t *testing.T) {
repo := newStubGenRepo()
repo.createErr = fmt.Errorf("db write error")
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
require.Error(t, err)
require.Nil(t, gen)
require.Contains(t, err.Error(), "create generation")
}
// ==================== MarkGenerating ====================
func TestMarkGenerating_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
require.NoError(t, err)
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
}
func TestMarkGenerating_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 999, "")
require.Error(t, err)
}
func TestMarkGenerating_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 1, "")
require.Error(t, err)
}
// ==================== MarkCompleted ====================
func TestMarkCompleted_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 1,
"https://cdn.example.com/video.mp4",
[]string{"https://cdn.example.com/video.mp4"},
SoraStorageTypeS3,
[]string{"sora/1/2024/01/01/uuid.mp4"},
1048576,
)
require.NoError(t, err)
gen := repo.gens[1]
require.Equal(t, SoraGenStatusCompleted, gen.Status)
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
require.Equal(t, int64(1048576), gen.FileSizeBytes)
require.NotNil(t, gen.CompletedAt)
}
func TestMarkCompleted_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
require.Error(t, err)
}
func TestMarkCompleted_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
require.Error(t, err)
}
// ==================== MarkFailed ====================
func TestMarkFailed_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
require.NoError(t, err)
gen := repo.gens[1]
require.Equal(t, SoraGenStatusFailed, gen.Status)
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
require.NotNil(t, gen.CompletedAt)
}
func TestMarkFailed_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 999, "error")
require.Error(t, err)
}
func TestMarkFailed_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 1, "err")
require.Error(t, err)
}
// ==================== MarkCancelled ====================
func TestMarkCancelled_Pending(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
require.NotNil(t, repo.gens[1].CompletedAt)
}
func TestMarkCancelled_Generating(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
}
func TestMarkCancelled_Completed(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
}
func TestMarkCancelled_Failed(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
func TestMarkCancelled_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 999)
require.Error(t, err)
}
func TestMarkCancelled_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
// ==================== GetByID ====================
func TestGetByID_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1), gen.ID)
require.Equal(t, "sora2-landscape-10s", gen.Model)
}
func TestGetByID_WrongUser(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 1, 1)
require.Error(t, err)
require.Nil(t, gen)
require.Contains(t, err.Error(), "无权访问")
}
func TestGetByID_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 999, 1)
require.Error(t, err)
require.Nil(t, gen)
}
// ==================== List ====================
func TestList_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
svc := NewSoraGenerationService(repo, nil, nil)
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, gens, 2) // 只有 userID=1 的
require.Equal(t, int64(2), total)
}
func TestList_DefaultPagination(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
require.NoError(t, err)
}
func TestList_MaxPageSize(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// pageSize > 100 → 应限制为 100
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
require.NoError(t, err)
}
func TestList_Error(t *testing.T) {
repo := newStubGenRepo()
repo.listErr = fmt.Errorf("db error")
svc := NewSoraGenerationService(repo, nil, nil)
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
require.Error(t, err)
}
// ==================== Delete ====================
func TestDelete_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDelete_WrongUser(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "无权删除")
}
func TestDelete_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 999, 1)
require.Error(t, err)
}
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // s3Storage 为 nil,跳过清理
}
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // quotaService 为 nil,跳过释放
}
func TestDelete_NonS3NoCleanup(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
}
func TestDelete_DeleteRepoError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
repo.deleteErr = fmt.Errorf("delete failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.Error(t, err)
}
// ==================== CountActiveByUser ====================
func TestCountActiveByUser_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
svc := NewSoraGenerationService(repo, nil, nil)
count, err := svc.CountActiveByUser(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(2), count)
}
func TestCountActiveByUser_NoActive(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
count, err := svc.CountActiveByUser(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(0), count)
}
func TestCountActiveByUser_Error(t *testing.T) {
repo := newStubGenRepo()
repo.countErr = fmt.Errorf("db error")
svc := NewSoraGenerationService(repo, nil, nil)
_, err := svc.CountActiveByUser(context.Background(), 1)
require.Error(t, err)
}
// ==================== ResolveMediaURLs ====================
func TestResolveMediaURLs_NilGen(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
}
func TestResolveMediaURLs_NonS3(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
}
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
}
func TestResolveMediaURLs_Local(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
}
// ==================== 状态流转完整测试 ====================
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// 1. 创建 pending
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
require.NoError(t, err)
require.Equal(t, SoraGenStatusPending, gen.Status)
// 2. 标记 generating
err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
require.NoError(t, err)
require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
// 3. 标记 completed
err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
}
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
require.NoError(t, err)
require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
}
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
err := svc.MarkCancelled(context.Background(), gen.ID)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
err := svc.MarkCancelled(context.Background(), gen.ID)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
// ==================== 权限隔离测试 ====================
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
// 用户 2 尝试访问用户 1 的记录
_, err := svc.GetByID(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权访问")
}
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
err := svc.Delete(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权删除")
}
// ==================== Delete: S3 清理 + 配额释放路径 ====================
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
// 验证 Delete 仍然成功(S3 错误只是记录日志)
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
}
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(repo, s3Storage, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // S3 清理失败不影响删除
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
// 有配额服务时,删除 S3 类型记录会释放配额
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
FileSizeBytes: 1048576, // 1MB
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
// 配额应被释放: 2MB - 1MB = 1MB
require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
// S3 清理 + 配额释放同时触发
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"key1"},
FileSizeBytes: 512,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(repo, s3Storage, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
_, exists := repo.gens[1]
require.False(t, exists)
require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
// 本地存储同样需要释放配额
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeLocal,
FileSizeBytes: 1024,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
// FileSizeBytes=0 跳过配额释放
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
FileSizeBytes: 0,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
}
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
"sora/1/2024/01/01/img3.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
// 主 URL 更新为第一个 key 的 CDN URL
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
// 多图 URLs 全部更新
require.Len(t, gen.MediaURLs, 3)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
}
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "original", gen.MediaURL) // 不变
}
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
// 使用无 settingService 的 S3 Storage,getClient 会失败
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // GetAccessURL 失败应传播错误
}
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
}
package service
import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
// SoraMediaCleanupService 定期清理本地媒体文件
type SoraMediaCleanupService struct {
storage *SoraMediaStorage
cfg *config.Config
cron *cron.Cron
startOnce sync.Once
stopOnce sync.Once
}
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
return &SoraMediaCleanupService{
storage: storage,
cfg: cfg,
}
}
func (s *SoraMediaCleanupService) Start() {
if s == nil || s.cfg == nil {
return
}
if !s.cfg.Sora.Storage.Cleanup.Enabled {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (disabled)")
return
}
if s.storage == nil || !s.storage.Enabled() {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (storage disabled)")
return
}
s.startOnce.Do(func() {
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
if schedule == "" {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (empty schedule)")
return
}
loc := time.Local
if strings.TrimSpace(s.cfg.Timezone) != "" {
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
loc = parsed
}
}
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
return
}
s.cron = c
s.cron.Start()
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
})
}
func (s *SoraMediaCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.cron != nil {
ctx := s.cron.Stop()
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cron stop timed out")
}
}
})
}
func (s *SoraMediaCleanupService) runCleanup() {
if s.cfg == nil || s.storage == nil {
return
}
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
if retention <= 0 {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] skipped (retention_days=%d)", retention)
return
}
cutoff := time.Now().AddDate(0, 0, -retention)
deleted := 0
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
for _, root := range roots {
if root == "" {
continue
}
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
if err != nil {
return nil
}
if info.IsDir() {
return nil
}
if info.ModTime().Before(cutoff) {
if rmErr := os.Remove(p); rmErr == nil {
deleted++
}
}
return nil
})
}
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cleanup finished, deleted=%d", deleted)
}
//go:build unit
package service
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaCleanupService_RunCleanup_NilCfg(t *testing.T) {
storage := &SoraMediaStorage{}
svc := &SoraMediaCleanupService{storage: storage, cfg: nil}
// 不应 panic
svc.runCleanup()
}
func TestSoraMediaCleanupService_RunCleanup_NilStorage(t *testing.T) {
cfg := &config.Config{}
svc := &SoraMediaCleanupService{storage: nil, cfg: cfg}
// 不应 panic
svc.runCleanup()
}
func TestSoraMediaCleanupService_RunCleanup_ZeroRetention(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
RetentionDays: 0,
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
// retention=0 应跳过清理
svc.runCleanup()
}
func TestSoraMediaCleanupService_Start_NilCfg(t *testing.T) {
svc := NewSoraMediaCleanupService(nil, nil)
svc.Start() // cfg == nil 时应直接返回
}
func TestSoraMediaCleanupService_Start_StorageDisabled(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
},
},
},
}
svc := NewSoraMediaCleanupService(nil, cfg)
svc.Start() // storage == nil 时应直接返回
}
func TestSoraMediaCleanupService_Start_WithTimezone(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Timezone: "Asia/Shanghai",
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "0 3 * * *",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
t.Cleanup(svc.Stop)
}
func TestSoraMediaCleanupService_Start_Disabled(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Cleanup: config.SoraStorageCleanupConfig{
Enabled: false,
},
},
},
}
svc := NewSoraMediaCleanupService(nil, cfg)
svc.Start() // 不应 panic,也不应启动 cron
}
func TestSoraMediaCleanupService_Start_NilSelf(t *testing.T) {
var svc *SoraMediaCleanupService
svc.Start() // 不应 panic
}
func TestSoraMediaCleanupService_Start_EmptySchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start() // 空 schedule 不应启动
}
func TestSoraMediaCleanupService_Start_InvalidSchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "invalid-cron",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start() // 无效 schedule 不应 panic
}
func TestSoraMediaCleanupService_Start_ValidSchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "0 3 * * *",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
t.Cleanup(svc.Stop)
}
func TestSoraMediaCleanupService_Stop_NilSelf(t *testing.T) {
var svc *SoraMediaCleanupService
svc.Stop() // 不应 panic
}
func TestSoraMediaCleanupService_Stop_WithoutStart(t *testing.T) {
svc := NewSoraMediaCleanupService(nil, &config.Config{})
svc.Stop() // cron 未启动时 Stop 不应 panic
}
func TestSoraMediaCleanupService_RunCleanup(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
RetentionDays: 1,
},
},
},
}
storage := NewSoraMediaStorage(cfg)
require.NoError(t, storage.EnsureLocalDirs())
oldImage := filepath.Join(storage.ImageRoot(), "old.png")
newVideo := filepath.Join(storage.VideoRoot(), "new.mp4")
require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644))
require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644))
oldTime := time.Now().Add(-48 * time.Hour)
require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime))
cleanup := NewSoraMediaCleanupService(storage, cfg)
cleanup.runCleanup()
require.NoFileExists(t, oldImage)
require.FileExists(t, newVideo)
}
package service
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func SignSoraMediaURL(path string, query string, expires int64, key string) string {
key = strings.TrimSpace(key)
if key == "" {
return ""
}
mac := hmac.New(sha256.New, []byte(key))
if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
return ""
}
if _, err := mac.Write([]byte("|")); err != nil {
return ""
}
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
return ""
}
return hex.EncodeToString(mac.Sum(nil))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool {
signature = strings.TrimSpace(signature)
if signature == "" {
return false
}
expected := SignSoraMediaURL(path, query, expires, key)
if expected == "" {
return false
}
return hmac.Equal([]byte(signature), []byte(expected))
}
func buildSoraMediaSignPayload(path string, query string) string {
if strings.TrimSpace(query) == "" {
return path
}
return path + "?" + query
}
package service
import "testing"
func TestSoraMediaSignVerify(t *testing.T) {
key := "test-key"
path := "/tmp/abc.png"
query := "a=1&b=2"
expires := int64(1700000000)
signature := SignSoraMediaURL(path, query, expires, key)
if signature == "" {
t.Fatal("签名为空")
}
if !VerifySoraMediaURL(path, query, expires, signature, key) {
t.Fatal("签名校验失败")
}
if VerifySoraMediaURL(path, "a=1", expires, signature, key) {
t.Fatal("签名参数不同仍然通过")
}
if VerifySoraMediaURL(path, query, expires+1, signature, key) {
t.Fatal("签名过期校验未失败")
}
}
func TestSoraMediaSignWithEmptyKey(t *testing.T) {
signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "")
if signature != "" {
t.Fatalf("空密钥不应生成签名")
}
if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") {
t.Fatalf("空密钥不应通过校验")
}
}
package service
import (
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const (
soraStorageDefaultRoot = "/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type SoraMediaStorage struct {
cfg *config.Config
root string
imageRoot string
videoRoot string
downloadTimeout time.Duration
maxDownloadBytes int64
fallbackToUpstream bool
debug bool
sem chan struct{}
ready bool
}
func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
storage := &SoraMediaStorage{cfg: cfg}
storage.refreshConfig()
if storage.Enabled() {
if err := storage.EnsureLocalDirs(); err != nil {
log.Printf("[SoraStorage] 初始化失败: %v", err)
}
}
return storage
}
func (s *SoraMediaStorage) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local"
}
func (s *SoraMediaStorage) Root() string {
if s == nil {
return ""
}
return s.root
}
func (s *SoraMediaStorage) ImageRoot() string {
if s == nil {
return ""
}
return s.imageRoot
}
func (s *SoraMediaStorage) VideoRoot() string {
if s == nil {
return ""
}
return s.videoRoot
}
func (s *SoraMediaStorage) refreshConfig() {
if s == nil || s.cfg == nil {
return
}
root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath)
if root == "" {
root = soraStorageDefaultRoot
}
root = filepath.Clean(root)
if !filepath.IsAbs(root) {
if absRoot, err := filepath.Abs(root); err == nil {
root = absRoot
}
}
s.root = root
s.imageRoot = filepath.Join(root, "image")
s.videoRoot = filepath.Join(root, "video")
maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads
if maxConcurrent <= 0 {
maxConcurrent = 4
}
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 120
}
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
if maxBytes <= 0 {
maxBytes = 200 << 20
}
s.maxDownloadBytes = maxBytes
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
s.debug = s.cfg.Sora.Storage.Debug
s.sem = make(chan struct{}, maxConcurrent)
}
// EnsureLocalDirs 创建并校验本地目录
func (s *SoraMediaStorage) EnsureLocalDirs() error {
if s == nil || !s.Enabled() {
return nil
}
if err := os.MkdirAll(s.imageRoot, 0o755); err != nil {
return fmt.Errorf("create image dir: %w", err)
}
if err := os.MkdirAll(s.videoRoot, 0o755); err != nil {
return fmt.Errorf("create video dir: %w", err)
}
s.ready = true
return nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) {
if len(urls) == 0 {
return nil, nil
}
if s == nil || !s.Enabled() {
return urls, nil
}
if !s.ready {
if err := s.EnsureLocalDirs(); err != nil {
return nil, err
}
}
results := make([]string, 0, len(urls))
for _, raw := range urls {
relative, err := s.downloadAndStore(ctx, mediaType, raw)
if err != nil {
if s.fallbackToUpstream {
results = append(results, raw)
continue
}
return nil, err
}
results = append(results, relative)
}
return results, nil
}
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
if s == nil || len(paths) == 0 {
return 0, nil
}
var total int64
for _, p := range paths {
localPath, err := s.resolveLocalPath(p)
if err != nil {
continue
}
info, err := os.Stat(localPath)
if err != nil {
if os.IsNotExist(err) {
continue
}
return 0, err
}
if info.Mode().IsRegular() {
total += info.Size()
}
}
return total, nil
}
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
if s == nil || len(paths) == 0 {
return nil
}
var lastErr error
for _, p := range paths {
localPath, err := s.resolveLocalPath(p)
if err != nil {
continue
}
if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) {
lastErr = err
}
}
return lastErr
}
func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
if s == nil || strings.TrimSpace(relativePath) == "" {
return "", errors.New("empty path")
}
cleaned := path.Clean(relativePath)
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
return "", errors.New("not a local media path")
}
if strings.TrimSpace(s.root) == "" {
return "", errors.New("storage root not configured")
}
relative := strings.TrimPrefix(cleaned, "/")
return filepath.Join(s.root, filepath.FromSlash(relative)), nil
}
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", errors.New("empty url")
}
root := s.imageRoot
if mediaType == "video" {
root = s.videoRoot
}
if root == "" {
return "", errors.New("storage root not configured")
}
retries := 3
for attempt := 1; attempt <= retries; attempt++ {
release, err := s.acquire(ctx)
if err != nil {
return "", err
}
relative, err := s.downloadOnce(ctx, root, mediaType, rawURL)
release()
if err == nil {
return relative, nil
}
if s.debug {
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
}
if attempt < retries {
time.Sleep(time.Duration(attempt*attempt) * time.Second)
continue
}
return "", err
}
return "", errors.New("download retries exhausted")
}
func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err
}
client := &http.Client{Timeout: s.downloadTimeout}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
}
ext := normalizeSoraFileExt(fileExtFromURL(rawURL))
if ext == "" {
ext = normalizeSoraFileExt(fileExtFromContentType(resp.Header.Get("Content-Type")))
}
if ext == "" {
ext = ".bin"
}
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
}
storageRoot, err := os.OpenRoot(root)
if err != nil {
return "", err
}
defer func() { _ = storageRoot.Close() }()
datePath := time.Now().Format("2006/01/02")
datePathFS := filepath.FromSlash(datePath)
if err := storageRoot.MkdirAll(datePathFS, 0o755); err != nil {
return "", err
}
filename := uuid.NewString() + ext
filePath := filepath.Join(datePathFS, filename)
out, err := storageRoot.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil {
return "", err
}
defer func() { _ = out.Close() }()
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
written, err := io.Copy(out, limited)
if err != nil {
removePartialDownload(storageRoot, filePath)
return "", err
}
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
removePartialDownload(storageRoot, filePath)
return "", fmt.Errorf("download size exceeds limit: %d", written)
}
relative := path.Join("/", mediaType, datePath, filename)
if s.debug {
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
}
return relative, nil
}
func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) {
if s.sem == nil {
return func() {}, nil
}
select {
case s.sem <- struct{}{}:
return func() { <-s.sem }, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func fileExtFromURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
return ""
}
ext := path.Ext(parsed.Path)
return strings.ToLower(ext)
}
func fileExtFromContentType(ct string) string {
if ct == "" {
return ""
}
if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 {
return strings.ToLower(exts[0])
}
return ""
}
func normalizeSoraFileExt(ext string) string {
ext = strings.ToLower(strings.TrimSpace(ext))
switch ext {
case ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg", ".tif", ".tiff", ".heic",
".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
return ext
default:
return ""
}
}
func removePartialDownload(root *os.Root, filePath string) {
if root == nil || strings.TrimSpace(filePath) == "" {
return
}
_ = root.Remove(filePath)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func sanitizeMediaLogURL(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil {
if len(rawURL) > 80 {
return rawURL[:80] + "..."
}
return rawURL
}
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
if len(safe) > 120 {
return safe[:120] + "..."
}
return safe
}
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaStorage_StoreFromURLs(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxConcurrentDownloads: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.NoError(t, err)
require.Len(t, urls, 1)
require.True(t, strings.HasPrefix(urls[0], "/image/"))
require.True(t, strings.HasSuffix(urls[0], ".png"))
localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/")))
require.FileExists(t, localPath)
}
func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
FallbackToUpstream: true,
},
},
}
storage := NewSoraMediaStorage(cfg)
url := server.URL + "/broken.png"
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url})
require.NoError(t, err)
require.Equal(t, []string{url}, urls)
}
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("too-large"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxDownloadBytes: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.Error(t, err)
}
func TestNormalizeSoraFileExt(t *testing.T) {
require.Equal(t, ".png", normalizeSoraFileExt(".PNG"))
require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4"))
require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd"))
require.Equal(t, "", normalizeSoraFileExt(".php"))
}
func TestRemovePartialDownload(t *testing.T) {
tmpDir := t.TempDir()
root, err := os.OpenRoot(tmpDir)
require.NoError(t, err)
defer func() { _ = root.Close() }()
filePath := "partial.bin"
f, err := root.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
require.NoError(t, err)
_, _ = f.WriteString("partial")
_ = f.Close()
removePartialDownload(root, filePath)
_, err = root.Stat(filePath)
require.Error(t, err)
require.True(t, os.IsNotExist(err))
}
package service
import (
"regexp"
"sort"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
// SoraModelConfig Sora 模型配置
type SoraModelConfig struct {
Type string
Width int
Height int
Orientation string
Frames int
Model string
Size string
RequirePro bool
// Prompt-enhance 专用参数
ExpansionLevel string
DurationS int
}
var soraModelConfigs = map[string]SoraModelConfig{
"gpt-image": {
Type: "image",
Width: 360,
Height: 360,
},
"gpt-image-landscape": {
Type: "image",
Width: 540,
Height: 360,
},
"gpt-image-portrait": {
Type: "image",
Width: 360,
Height: 540,
},
"sora2-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-hd-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"prompt-enhance-short-10s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 10,
},
"prompt-enhance-short-15s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 15,
},
"prompt-enhance-short-20s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 20,
},
"prompt-enhance-medium-10s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 10,
},
"prompt-enhance-medium-15s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 15,
},
"prompt-enhance-medium-20s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 20,
},
"prompt-enhance-long-10s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 10,
},
"prompt-enhance-long-15s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 15,
},
"prompt-enhance-long-20s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 20,
},
}
var soraModelIDs = []string{
"gpt-image",
"gpt-image-landscape",
"gpt-image-portrait",
"sora2-landscape-10s",
"sora2-portrait-10s",
"sora2-landscape-15s",
"sora2-portrait-15s",
"sora2-landscape-25s",
"sora2-portrait-25s",
"sora2pro-landscape-10s",
"sora2pro-portrait-10s",
"sora2pro-landscape-15s",
"sora2pro-portrait-15s",
"sora2pro-landscape-25s",
"sora2pro-portrait-25s",
"sora2pro-hd-landscape-10s",
"sora2pro-hd-portrait-10s",
"sora2pro-hd-landscape-15s",
"sora2pro-hd-portrait-15s",
"prompt-enhance-short-10s",
"prompt-enhance-short-15s",
"prompt-enhance-short-20s",
"prompt-enhance-medium-10s",
"prompt-enhance-medium-15s",
"prompt-enhance-medium-20s",
"prompt-enhance-long-10s",
"prompt-enhance-long-15s",
"prompt-enhance-long-20s",
}
// GetSoraModelConfig 返回 Sora 模型配置
func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
key := strings.ToLower(strings.TrimSpace(model))
cfg, ok := soraModelConfigs[key]
return cfg, ok
}
// SoraModelFamily 模型家族(前端 Sora 客户端使用)
type SoraModelFamily struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Orientations []string `json:"orientations"`
Durations []int `json:"durations,omitempty"`
}
var (
videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)
soraFamilyNames = map[string]string{
"sora2": "Sora 2",
"sora2pro": "Sora 2 Pro",
"sora2pro-hd": "Sora 2 Pro HD",
"gpt-image": "GPT Image",
}
)
// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
func BuildSoraModelFamilies() []SoraModelFamily {
type familyData struct {
modelType string
orientations map[string]bool
durations map[int]bool
}
families := make(map[string]*familyData)
for id, cfg := range soraModelConfigs {
if cfg.Type == "prompt_enhance" {
continue
}
var famID, orientation string
var duration int
switch cfg.Type {
case "video":
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
famID = id[:len(id)-len(m[0])]
orientation = m[1]
duration, _ = strconv.Atoi(m[2])
}
case "image":
if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
famID = id[:len(id)-len(m[0])]
orientation = m[1]
} else {
famID = id
orientation = "square"
}
}
if famID == "" {
continue
}
fd, ok := families[famID]
if !ok {
fd = &familyData{
modelType: cfg.Type,
orientations: make(map[string]bool),
durations: make(map[int]bool),
}
families[famID] = fd
}
if orientation != "" {
fd.orientations[orientation] = true
}
if duration > 0 {
fd.durations[duration] = true
}
}
// 排序:视频在前、图像在后,同类按名称排序
famIDs := make([]string, 0, len(families))
for id := range families {
famIDs = append(famIDs, id)
}
sort.Slice(famIDs, func(i, j int) bool {
fi, fj := families[famIDs[i]], families[famIDs[j]]
if fi.modelType != fj.modelType {
return fi.modelType == "video"
}
return famIDs[i] < famIDs[j]
})
result := make([]SoraModelFamily, 0, len(famIDs))
for _, famID := range famIDs {
fd := families[famID]
fam := SoraModelFamily{
ID: famID,
Name: soraFamilyNames[famID],
Type: fd.modelType,
}
if fam.Name == "" {
fam.Name = famID
}
for o := range fd.orientations {
fam.Orientations = append(fam.Orientations, o)
}
sort.Strings(fam.Orientations)
for d := range fd.durations {
fam.Durations = append(fam.Durations, d)
}
sort.Ints(fam.Durations)
result = append(result, fam)
}
return result
}
// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
// 通过命名约定自动识别视频/图像模型并分组。
func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
type familyData struct {
modelType string
orientations map[string]bool
durations map[int]bool
}
families := make(map[string]*familyData)
for _, id := range modelIDs {
id = strings.ToLower(strings.TrimSpace(id))
if id == "" || strings.HasPrefix(id, "prompt-enhance") {
continue
}
var famID, orientation, modelType string
var duration int
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
// 视频模型: {family}-{orientation}-{duration}s
famID = id[:len(id)-len(m[0])]
orientation = m[1]
duration, _ = strconv.Atoi(m[2])
modelType = "video"
} else if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
// 图像模型(带方向): {family}-{orientation}
famID = id[:len(id)-len(m[0])]
orientation = m[1]
modelType = "image"
} else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" {
// 已知的无后缀图像模型(如 gpt-image)
famID = id
orientation = "square"
modelType = "image"
} else if strings.Contains(id, "image") {
// 未知但名称包含 image 的模型,推断为图像模型
famID = id
orientation = "square"
modelType = "image"
} else {
continue
}
if famID == "" {
continue
}
fd, ok := families[famID]
if !ok {
fd = &familyData{
modelType: modelType,
orientations: make(map[string]bool),
durations: make(map[int]bool),
}
families[famID] = fd
}
if orientation != "" {
fd.orientations[orientation] = true
}
if duration > 0 {
fd.durations[duration] = true
}
}
famIDs := make([]string, 0, len(families))
for id := range families {
famIDs = append(famIDs, id)
}
sort.Slice(famIDs, func(i, j int) bool {
fi, fj := families[famIDs[i]], families[famIDs[j]]
if fi.modelType != fj.modelType {
return fi.modelType == "video"
}
return famIDs[i] < famIDs[j]
})
result := make([]SoraModelFamily, 0, len(famIDs))
for _, famID := range famIDs {
fd := families[famID]
fam := SoraModelFamily{
ID: famID,
Name: soraFamilyNames[famID],
Type: fd.modelType,
}
if fam.Name == "" {
fam.Name = famID
}
for o := range fd.orientations {
fam.Orientations = append(fam.Orientations, o)
}
sort.Strings(fam.Orientations)
for d := range fd.durations {
fam.Durations = append(fam.Durations, d)
}
sort.Ints(fam.Durations)
result = append(result, fam)
}
return result
}
// DefaultSoraModels returns the default Sora model list.
func DefaultSoraModels(cfg *config.Config) []openai.Model {
models := make([]openai.Model, 0, len(soraModelIDs))
for _, id := range soraModelIDs {
models = append(models, openai.Model{
ID: id,
Object: "model",
OwnedBy: "openai",
Type: "model",
DisplayName: id,
})
}
if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := models[:0]
for _, model := range models {
if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, model)
}
models = filtered
}
return models
}
package service
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService 管理 Sora 用户存储配额。
// 配额优先级:用户级 → 分组级 → 系统默认值。
type SoraQuotaService struct {
userRepo UserRepository
groupRepo GroupRepository
settingService *SettingService
}
// NewSoraQuotaService 创建配额服务实例。
func NewSoraQuotaService(
userRepo UserRepository,
groupRepo GroupRepository,
settingService *SettingService,
) *SoraQuotaService {
return &SoraQuotaService{
userRepo: userRepo,
groupRepo: groupRepo,
settingService: settingService,
}
}
// QuotaInfo 返回给客户端的配额信息。
type QuotaInfo struct {
QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制)
UsedBytes int64 `json:"used_bytes"` // 已使用
AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0)
QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited
Source string `json:"source,omitempty"` // 兼容旧字段
}
// ErrSoraStorageQuotaExceeded 表示配额不足。
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
// QuotaExceededError 包含配额不足的上下文信息。
type QuotaExceededError struct {
QuotaBytes int64
UsedBytes int64
}
func (e *QuotaExceededError) Error() string {
if e == nil {
return "存储配额不足"
}
return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
}
type soraQuotaAtomicUserRepository interface {
AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
}
// GetQuota 获取用户的存储配额信息。
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
info := &QuotaInfo{
UsedBytes: user.SoraStorageUsedBytes,
}
// 1. 用户级配额
if user.SoraStorageQuotaBytes > 0 {
info.QuotaBytes = user.SoraStorageQuotaBytes
info.QuotaSource = "user"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 2. 分组级配额(取用户可用分组中最大的配额)
if len(user.AllowedGroups) > 0 {
var maxGroupQuota int64
for _, gid := range user.AllowedGroups {
group, err := s.groupRepo.GetByID(ctx, gid)
if err != nil {
continue
}
if group.SoraStorageQuotaBytes > maxGroupQuota {
maxGroupQuota = group.SoraStorageQuotaBytes
}
}
if maxGroupQuota > 0 {
info.QuotaBytes = maxGroupQuota
info.QuotaSource = "group"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
}
// 3. 系统默认值
defaultQuota := s.getSystemDefaultQuota(ctx)
if defaultQuota > 0 {
info.QuotaBytes = defaultQuota
info.QuotaSource = "system"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 无配额限制
info.QuotaSource = "unlimited"
info.Source = info.QuotaSource
info.AvailableBytes = 0
return info, nil
}
// CheckQuota 检查用户是否有足够的存储配额。
// 返回 nil 表示配额充足或无限制。
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
// 0 表示无限制
if quota.QuotaBytes == 0 {
return nil
}
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return nil
}
// AddUsage 原子累加用量(上传成功后调用)。
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
if err != nil {
if errors.Is(err, ErrSoraStorageQuotaExceeded) {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return fmt.Errorf("update user quota usage (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota update: %w", err)
}
user.SoraStorageUsedBytes += bytes
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota usage: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
// ReleaseUsage 释放用量(删除文件后调用)。
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
if err != nil {
return fmt.Errorf("update user quota release (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota release: %w", err)
}
user.SoraStorageUsedBytes -= bytes
if user.SoraStorageUsedBytes < 0 {
user.SoraStorageUsedBytes = 0
}
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota release: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
if quotaBytes <= 0 {
return 0
}
if usedBytes >= quotaBytes {
return 0
}
return quotaBytes - usedBytes
}
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
if s.settingService == nil {
return 0
}
settings, err := s.settingService.GetSoraS3Settings(ctx)
if err != nil {
return 0
}
return settings.DefaultStorageQuotaBytes
}
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
return s.getSystemDefaultQuota(ctx)
}
// SetUserQuota 设置用户级配额(管理员操作)。
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
user, err := userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
user.SoraStorageQuotaBytes = quotaBytes
return userRepo.Update(ctx, user)
}
// ParseQuotaBytes 解析配额字符串为字节数。
func ParseQuotaBytes(s string) int64 {
v, _ := strconv.ParseInt(s, 10, 64)
return v
}
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
var _ GroupRepository = (*stubGroupRepoForQuota)(nil)
type stubGroupRepoForQuota struct {
groups map[int64]*Group
}
func newStubGroupRepoForQuota() *stubGroupRepoForQuota {
return &stubGroupRepoForQuota{groups: make(map[int64]*Group)}
}
func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) {
if g, ok := r.groups[id]; ok {
return g, nil
}
return nil, fmt.Errorf("group not found")
}
func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil }
func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return r.GetByID(context.Background(), id)
}
func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil }
func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil }
func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil }
func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error {
return nil
}
func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
return nil
}
// ==================== Stub: SettingRepository (用于 SettingService) ====================
var _ SettingRepository = (*stubSettingRepoForQuota)(nil)
type stubSettingRepoForQuota struct {
values map[string]string
}
func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota {
if values == nil {
values = make(map[string]string)
}
return &stubSettingRepoForQuota{values: values}
}
func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) {
if v, ok := r.values[key]; ok {
return &Setting{Key: key, Value: v}, nil
}
return nil, ErrSettingNotFound
}
func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) {
if v, ok := r.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value string) error {
r.values[key] = value
return nil
}
func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string)
for _, k := range keys {
if v, ok := r.values[k]; ok {
result[k] = v
}
}
return result, nil
}
func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error {
for k, v := range settings {
r.values[k] = v
}
return nil
}
func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) {
return r.values, nil
}
func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error {
delete(r.values, key)
return nil
}
// ==================== GetQuota ====================
func TestGetQuota_UserLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB
SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB
}
svc := NewSoraQuotaService(userRepo, nil, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(10*1024*1024), quota.QuotaBytes)
require.Equal(t, int64(3*1024*1024), quota.UsedBytes)
require.Equal(t, "user", quota.Source)
}
func TestGetQuota_GroupLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0, // 用户级无配额
SoraStorageUsedBytes: 1024,
AllowedGroups: []int64{10, 20},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024}
groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024}
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值
require.Equal(t, "group", quota.Source)
}
func TestGetQuota_SystemLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512}
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, nil, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(104857600), quota.QuotaBytes)
require.Equal(t, "system", quota.Source)
}
func TestGetQuota_NoLimit(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0}
svc := NewSoraQuotaService(userRepo, nil, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(0), quota.QuotaBytes)
require.Equal(t, "unlimited", quota.Source)
}
func TestGetQuota_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
_, err := svc.GetQuota(context.Background(), 999)
require.Error(t, err)
require.Contains(t, err.Error(), "get user")
}
func TestGetQuota_GroupRepoError(t *testing.T) {
// 分组获取失败时跳过该分组(不影响整体)
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1, SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{999}, // 不存在的分组
}
groupRepo := newStubGroupRepoForQuota()
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制
}
// ==================== CheckQuota ====================
func TestCheckQuota_Sufficient(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 3 * 1024 * 1024,
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 1024)
require.NoError(t, err)
}
func TestCheckQuota_Exceeded(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "配额不足")
}
func TestCheckQuota_NoLimit(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0, // 无限制
SoraStorageUsedBytes: 1000000000,
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 999999999)
require.NoError(t, err) // 无限制时始终通过
}
func TestCheckQuota_ExactBoundary(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 1024,
SoraStorageUsedBytes: 1024, // 恰好满
}
svc := NewSoraQuotaService(userRepo, nil, nil)
// 额外 0 字节不超
require.NoError(t, svc.CheckQuota(context.Background(), 1, 0))
// 额外 1 字节超出
require.Error(t, svc.CheckQuota(context.Background(), 1, 1))
}
// ==================== AddUsage ====================
func TestAddUsage_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 2048)
require.NoError(t, err)
require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes)
}
func TestAddUsage_ZeroBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 0)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestAddUsage_NegativeBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, -100)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestAddUsage_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 999, 1024)
require.Error(t, err)
}
func TestAddUsage_UpdateError(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0}
userRepo.updateErr = fmt.Errorf("db error")
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 1024)
require.Error(t, err)
require.Contains(t, err.Error(), "update user quota usage")
}
// ==================== ReleaseUsage ====================
func TestReleaseUsage_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 1024)
require.NoError(t, err)
require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes)
}
func TestReleaseUsage_ClampToZero(t *testing.T) {
// 释放量大于已用量时,应 clamp 到 0
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 1000)
require.NoError(t, err)
require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes)
}
func TestReleaseUsage_ZeroBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 0)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestReleaseUsage_NegativeBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, -50)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestReleaseUsage_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 999, 1024)
require.Error(t, err)
}
func TestReleaseUsage_UpdateError(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
userRepo.updateErr = fmt.Errorf("db error")
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 512)
require.Error(t, err)
require.Contains(t, err.Error(), "update user quota release")
}
// ==================== GetQuotaFromSettings ====================
func TestGetQuotaFromSettings_NilSettingService(t *testing.T) {
svc := NewSoraQuotaService(nil, nil, nil)
require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background()))
}
func TestGetQuotaFromSettings_WithSettings(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(nil, nil, settingService)
require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background()))
}
// ==================== SetUserSoraQuota ====================
func TestSetUserSoraQuota_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0}
err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024)
require.NoError(t, err)
require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes)
}
func TestSetUserSoraQuota_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024)
require.Error(t, err)
}
// ==================== ParseQuotaBytes ====================
func TestParseQuotaBytes(t *testing.T) {
require.Equal(t, int64(1048576), ParseQuotaBytes("1048576"))
require.Equal(t, int64(0), ParseQuotaBytes(""))
require.Equal(t, int64(0), ParseQuotaBytes("abc"))
require.Equal(t, int64(-1), ParseQuotaBytes("-1"))
}
// ==================== 优先级完整测试 ====================
func TestQuotaPriority_UserOverridesGroup(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 5 * 1024 * 1024,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "user", quota.Source) // 用户级优先
require.Equal(t, int64(5*1024*1024), quota.QuotaBytes)
}
func TestQuotaPriority_GroupOverridesSystem(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "group", quota.Source) // 分组级优先于系统
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes)
}
func TestQuotaPriority_FallbackToSystem(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "system", quota.Source)
require.Equal(t, int64(52428800), quota.QuotaBytes)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment