Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
eb2dce92
Commit
eb2dce92
authored
Apr 06, 2026
by
陈曦
Browse files
升级v1.0.8 解决冲突
parents
7b83d6e7
339d906e
Changes
178
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/openai_token_provider_test.go
View file @
eb2dce92
...
...
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
backend/internal/service/setting_service.go
View file @
eb2dce92
...
...
@@ -22,8 +22,6 @@ import (
var
(
ErrRegistrationDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrSettingNotFound
=
infraerrors
.
NotFound
(
"SETTING_NOT_FOUND"
,
"setting not found"
)
ErrSoraS3ProfileNotFound
=
infraerrors
.
NotFound
(
"SORA_S3_PROFILE_NOT_FOUND"
,
"sora s3 profile not found"
)
ErrSoraS3ProfileExists
=
infraerrors
.
Conflict
(
"SORA_S3_PROFILE_EXISTS"
,
"sora s3 profile already exists"
)
ErrDefaultSubGroupInvalid
=
infraerrors
.
BadRequest
(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID"
,
"default subscription group must exist and be subscription type"
,
...
...
@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader
DefaultSubscriptionGroupReader
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
onS3Update
func
()
// Callback when Sora S3 settings are updated
version
string
// Application version
}
...
...
@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeySoraClientEnabled
,
SettingKeyCustomMenuItems
,
SettingKeyCustomEndpoints
,
SettingKeyLinuxDoConnectEnabled
,
...
...
@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
...
...
@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s
.
onUpdate
=
callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func
(
s
*
SettingService
)
SetOnS3UpdateCallback
(
callback
func
())
{
s
.
onS3Update
=
callback
}
// SetVersion sets the application version for injection into public settings
func
(
s
*
SettingService
)
SetVersion
(
version
string
)
{
s
.
version
=
version
...
...
@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
CustomMenuItems
json
.
RawMessage
`json:"custom_menu_items"`
CustomEndpoints
json
.
RawMessage
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
...
...
@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
SoraClientEnabled
:
settings
.
SoraClientEnabled
,
CustomMenuItems
:
filterUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
safeRawJSONArray
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
...
...
@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
updates
[
SettingKeySoraClientEnabled
]
=
strconv
.
FormatBool
(
settings
.
SoraClientEnabled
)
updates
[
SettingKeyCustomMenuItems
]
=
settings
.
CustomMenuItems
updates
[
SettingKeyCustomEndpoints
]
=
settings
.
CustomEndpoints
...
...
@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeySoraClientEnabled
:
"false"
,
SettingKeyCustomMenuItems
:
"[]"
,
SettingKeyCustomEndpoints
:
"[]"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
...
...
@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
BackendModeEnabled
:
settings
[
SettingKeyBackendModeEnabled
]
==
"true"
,
...
...
@@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyStreamTimeoutSettings
,
string
(
data
))
}
type
soraS3ProfilesStore
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
soraS3ProfileStoreItem
`json:"items"`
}
type
soraS3ProfileStoreItem
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func
(
s
*
SettingService
)
GetSoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
profiles
,
err
:=
s
.
ListSoraS3Profiles
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
activeProfile
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
activeProfile
==
nil
{
return
&
SoraS3Settings
{},
nil
}
return
&
SoraS3Settings
{
Enabled
:
activeProfile
.
Enabled
,
Endpoint
:
activeProfile
.
Endpoint
,
Region
:
activeProfile
.
Region
,
Bucket
:
activeProfile
.
Bucket
,
AccessKeyID
:
activeProfile
.
AccessKeyID
,
SecretAccessKey
:
activeProfile
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
activeProfile
.
SecretAccessKeyConfigured
,
Prefix
:
activeProfile
.
Prefix
,
ForcePathStyle
:
activeProfile
.
ForcePathStyle
,
CDNURL
:
activeProfile
.
CDNURL
,
DefaultStorageQuotaBytes
:
activeProfile
.
DefaultStorageQuotaBytes
,
},
nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func
(
s
*
SettingService
)
SetSoraS3Settings
(
ctx
context
.
Context
,
settings
*
SoraS3Settings
)
error
{
if
settings
==
nil
{
return
fmt
.
Errorf
(
"settings cannot be nil"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
activeIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
store
.
ActiveProfileID
)
if
activeIndex
<
0
{
activeID
:=
"default"
if
hasSoraS3ProfileID
(
store
.
Items
,
activeID
)
{
activeID
=
fmt
.
Sprintf
(
"default-%d"
,
time
.
Now
()
.
Unix
())
}
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
activeID
,
Name
:
"Default"
,
UpdatedAt
:
now
,
})
store
.
ActiveProfileID
=
activeID
activeIndex
=
len
(
store
.
Items
)
-
1
}
active
:=
store
.
Items
[
activeIndex
]
active
.
Enabled
=
settings
.
Enabled
active
.
Endpoint
=
strings
.
TrimSpace
(
settings
.
Endpoint
)
active
.
Region
=
strings
.
TrimSpace
(
settings
.
Region
)
active
.
Bucket
=
strings
.
TrimSpace
(
settings
.
Bucket
)
active
.
AccessKeyID
=
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
active
.
Prefix
=
strings
.
TrimSpace
(
settings
.
Prefix
)
active
.
ForcePathStyle
=
settings
.
ForcePathStyle
active
.
CDNURL
=
strings
.
TrimSpace
(
settings
.
CDNURL
)
active
.
DefaultStorageQuotaBytes
=
maxInt64
(
settings
.
DefaultStorageQuotaBytes
,
0
)
if
settings
.
SecretAccessKey
!=
""
{
active
.
SecretAccessKey
=
settings
.
SecretAccessKey
}
active
.
UpdatedAt
=
now
store
.
Items
[
activeIndex
]
=
active
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func
(
s
*
SettingService
)
ListSoraS3Profiles
(
ctx
context
.
Context
)
(
*
SoraS3ProfileList
,
error
)
{
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
convertSoraS3ProfilesStore
(
store
),
nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func
(
s
*
SettingService
)
CreateSoraS3Profile
(
ctx
context
.
Context
,
profile
*
SoraS3Profile
,
setActive
bool
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
profileID
:=
strings
.
TrimSpace
(
profile
.
ProfileID
)
if
profileID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
if
hasSoraS3ProfileID
(
store
.
Items
,
profileID
)
{
return
nil
,
ErrSoraS3ProfileExists
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
profileID
,
Name
:
name
,
Enabled
:
profile
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
profile
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
profile
.
Region
),
Bucket
:
strings
.
TrimSpace
(
profile
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
profile
.
AccessKeyID
),
SecretAccessKey
:
profile
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
profile
.
Prefix
),
ForcePathStyle
:
profile
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
profile
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
})
if
setActive
||
store
.
ActiveProfileID
==
""
{
store
.
ActiveProfileID
=
profileID
}
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
created
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
profileID
)
if
created
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
created
,
nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func
(
s
*
SettingService
)
UpdateSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
,
profile
*
SoraS3Profile
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
target
:=
store
.
Items
[
targetIndex
]
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
target
.
Name
=
name
target
.
Enabled
=
profile
.
Enabled
target
.
Endpoint
=
strings
.
TrimSpace
(
profile
.
Endpoint
)
target
.
Region
=
strings
.
TrimSpace
(
profile
.
Region
)
target
.
Bucket
=
strings
.
TrimSpace
(
profile
.
Bucket
)
target
.
AccessKeyID
=
strings
.
TrimSpace
(
profile
.
AccessKeyID
)
target
.
Prefix
=
strings
.
TrimSpace
(
profile
.
Prefix
)
target
.
ForcePathStyle
=
profile
.
ForcePathStyle
target
.
CDNURL
=
strings
.
TrimSpace
(
profile
.
CDNURL
)
target
.
DefaultStorageQuotaBytes
=
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
)
if
profile
.
SecretAccessKey
!=
""
{
target
.
SecretAccessKey
=
profile
.
SecretAccessKey
}
target
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
[
targetIndex
]
=
target
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
updated
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
targetID
)
if
updated
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
updated
,
nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func
(
s
*
SettingService
)
DeleteSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
ErrSoraS3ProfileNotFound
}
store
.
Items
=
append
(
store
.
Items
[
:
targetIndex
],
store
.
Items
[
targetIndex
+
1
:
]
...
)
if
store
.
ActiveProfileID
==
targetID
{
store
.
ActiveProfileID
=
""
if
len
(
store
.
Items
)
>
0
{
store
.
ActiveProfileID
=
store
.
Items
[
0
]
.
ProfileID
}
}
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func
(
s
*
SettingService
)
SetActiveSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
*
SoraS3Profile
,
error
)
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
store
.
ActiveProfileID
=
targetID
store
.
Items
[
targetIndex
]
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
active
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
active
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
active
,
nil
}
func
(
s
*
SettingService
)
loadSoraS3ProfilesStore
(
ctx
context
.
Context
)
(
*
soraS3ProfilesStore
,
error
)
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySoraS3Profiles
)
if
err
==
nil
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
&
soraS3ProfilesStore
{},
nil
}
var
store
soraS3ProfilesStore
if
unmarshalErr
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
store
);
unmarshalErr
!=
nil
{
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal sora s3 profiles: %w"
,
unmarshalErr
)
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
normalized
:=
normalizeSoraS3ProfilesStore
(
store
)
return
&
normalized
,
nil
}
if
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"get sora s3 profiles: %w"
,
err
)
}
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
legacyErr
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
func
(
s
*
SettingService
)
persistSoraS3ProfilesStore
(
ctx
context
.
Context
,
store
*
soraS3ProfilesStore
)
error
{
if
store
==
nil
{
return
fmt
.
Errorf
(
"sora s3 profiles store cannot be nil"
)
}
normalized
:=
normalizeSoraS3ProfilesStore
(
*
store
)
data
,
err
:=
json
.
Marshal
(
normalized
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal sora s3 profiles: %w"
,
err
)
}
updates
:=
map
[
string
]
string
{
SettingKeySoraS3Profiles
:
string
(
data
),
}
active
:=
pickActiveSoraS3ProfileFromStore
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
if
active
==
nil
{
updates
[
SettingKeySoraS3Enabled
]
=
"false"
updates
[
SettingKeySoraS3Endpoint
]
=
""
updates
[
SettingKeySoraS3Region
]
=
""
updates
[
SettingKeySoraS3Bucket
]
=
""
updates
[
SettingKeySoraS3AccessKeyID
]
=
""
updates
[
SettingKeySoraS3Prefix
]
=
""
updates
[
SettingKeySoraS3ForcePathStyle
]
=
"false"
updates
[
SettingKeySoraS3CDNURL
]
=
""
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
"0"
updates
[
SettingKeySoraS3SecretAccessKey
]
=
""
}
else
{
updates
[
SettingKeySoraS3Enabled
]
=
strconv
.
FormatBool
(
active
.
Enabled
)
updates
[
SettingKeySoraS3Endpoint
]
=
strings
.
TrimSpace
(
active
.
Endpoint
)
updates
[
SettingKeySoraS3Region
]
=
strings
.
TrimSpace
(
active
.
Region
)
updates
[
SettingKeySoraS3Bucket
]
=
strings
.
TrimSpace
(
active
.
Bucket
)
updates
[
SettingKeySoraS3AccessKeyID
]
=
strings
.
TrimSpace
(
active
.
AccessKeyID
)
updates
[
SettingKeySoraS3Prefix
]
=
strings
.
TrimSpace
(
active
.
Prefix
)
updates
[
SettingKeySoraS3ForcePathStyle
]
=
strconv
.
FormatBool
(
active
.
ForcePathStyle
)
updates
[
SettingKeySoraS3CDNURL
]
=
strings
.
TrimSpace
(
active
.
CDNURL
)
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
strconv
.
FormatInt
(
maxInt64
(
active
.
DefaultStorageQuotaBytes
,
0
),
10
)
updates
[
SettingKeySoraS3SecretAccessKey
]
=
active
.
SecretAccessKey
}
if
err
:=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
);
err
!=
nil
{
return
err
}
if
s
.
onUpdate
!=
nil
{
s
.
onUpdate
()
}
if
s
.
onS3Update
!=
nil
{
s
.
onS3Update
()
}
return
nil
}
func
(
s
*
SettingService
)
getLegacySoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
keys
:=
[]
string
{
SettingKeySoraS3Enabled
,
SettingKeySoraS3Endpoint
,
SettingKeySoraS3Region
,
SettingKeySoraS3Bucket
,
SettingKeySoraS3AccessKeyID
,
SettingKeySoraS3SecretAccessKey
,
SettingKeySoraS3Prefix
,
SettingKeySoraS3ForcePathStyle
,
SettingKeySoraS3CDNURL
,
SettingKeySoraDefaultStorageQuotaBytes
,
}
values
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get legacy sora s3 settings: %w"
,
err
)
}
result
:=
&
SoraS3Settings
{
Enabled
:
values
[
SettingKeySoraS3Enabled
]
==
"true"
,
Endpoint
:
values
[
SettingKeySoraS3Endpoint
],
Region
:
values
[
SettingKeySoraS3Region
],
Bucket
:
values
[
SettingKeySoraS3Bucket
],
AccessKeyID
:
values
[
SettingKeySoraS3AccessKeyID
],
SecretAccessKey
:
values
[
SettingKeySoraS3SecretAccessKey
],
SecretAccessKeyConfigured
:
values
[
SettingKeySoraS3SecretAccessKey
]
!=
""
,
Prefix
:
values
[
SettingKeySoraS3Prefix
],
ForcePathStyle
:
values
[
SettingKeySoraS3ForcePathStyle
]
==
"true"
,
CDNURL
:
values
[
SettingKeySoraS3CDNURL
],
}
if
v
,
parseErr
:=
strconv
.
ParseInt
(
values
[
SettingKeySoraDefaultStorageQuotaBytes
],
10
,
64
);
parseErr
==
nil
{
result
.
DefaultStorageQuotaBytes
=
v
}
return
result
,
nil
}
func
normalizeSoraS3ProfilesStore
(
store
soraS3ProfilesStore
)
soraS3ProfilesStore
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
store
.
Items
))
normalized
:=
soraS3ProfilesStore
{
ActiveProfileID
:
strings
.
TrimSpace
(
store
.
ActiveProfileID
),
Items
:
make
([]
soraS3ProfileStoreItem
,
0
,
len
(
store
.
Items
)),
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
item
.
ProfileID
=
strings
.
TrimSpace
(
item
.
ProfileID
)
if
item
.
ProfileID
==
""
{
item
.
ProfileID
=
fmt
.
Sprintf
(
"profile-%d"
,
idx
+
1
)
}
if
_
,
exists
:=
seen
[
item
.
ProfileID
];
exists
{
continue
}
seen
[
item
.
ProfileID
]
=
struct
{}{}
item
.
Name
=
strings
.
TrimSpace
(
item
.
Name
)
if
item
.
Name
==
""
{
item
.
Name
=
item
.
ProfileID
}
item
.
Endpoint
=
strings
.
TrimSpace
(
item
.
Endpoint
)
item
.
Region
=
strings
.
TrimSpace
(
item
.
Region
)
item
.
Bucket
=
strings
.
TrimSpace
(
item
.
Bucket
)
item
.
AccessKeyID
=
strings
.
TrimSpace
(
item
.
AccessKeyID
)
item
.
Prefix
=
strings
.
TrimSpace
(
item
.
Prefix
)
item
.
CDNURL
=
strings
.
TrimSpace
(
item
.
CDNURL
)
item
.
DefaultStorageQuotaBytes
=
maxInt64
(
item
.
DefaultStorageQuotaBytes
,
0
)
item
.
UpdatedAt
=
strings
.
TrimSpace
(
item
.
UpdatedAt
)
if
item
.
UpdatedAt
==
""
{
item
.
UpdatedAt
=
now
}
normalized
.
Items
=
append
(
normalized
.
Items
,
item
)
}
if
len
(
normalized
.
Items
)
==
0
{
normalized
.
ActiveProfileID
=
""
return
normalized
}
if
findSoraS3ProfileIndex
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
>=
0
{
return
normalized
}
normalized
.
ActiveProfileID
=
normalized
.
Items
[
0
]
.
ProfileID
return
normalized
}
func
convertSoraS3ProfilesStore
(
store
*
soraS3ProfilesStore
)
*
SoraS3ProfileList
{
if
store
==
nil
{
return
&
SoraS3ProfileList
{}
}
items
:=
make
([]
SoraS3Profile
,
0
,
len
(
store
.
Items
))
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
items
=
append
(
items
,
SoraS3Profile
{
ProfileID
:
item
.
ProfileID
,
Name
:
item
.
Name
,
IsActive
:
item
.
ProfileID
==
store
.
ActiveProfileID
,
Enabled
:
item
.
Enabled
,
Endpoint
:
item
.
Endpoint
,
Region
:
item
.
Region
,
Bucket
:
item
.
Bucket
,
AccessKeyID
:
item
.
AccessKeyID
,
SecretAccessKey
:
item
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
item
.
SecretAccessKey
!=
""
,
Prefix
:
item
.
Prefix
,
ForcePathStyle
:
item
.
ForcePathStyle
,
CDNURL
:
item
.
CDNURL
,
DefaultStorageQuotaBytes
:
item
.
DefaultStorageQuotaBytes
,
UpdatedAt
:
item
.
UpdatedAt
,
})
}
return
&
SoraS3ProfileList
{
ActiveProfileID
:
store
.
ActiveProfileID
,
Items
:
items
,
}
}
func
pickActiveSoraS3Profile
(
items
[]
SoraS3Profile
,
activeProfileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileByID
(
items
[]
SoraS3Profile
,
profileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
&
items
[
idx
]
}
}
return
nil
}
func
pickActiveSoraS3ProfileFromStore
(
items
[]
soraS3ProfileStoreItem
,
activeProfileID
string
)
*
soraS3ProfileStoreItem
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileIndex
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
int
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
idx
}
}
return
-
1
}
func
hasSoraS3ProfileID
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
bool
{
return
findSoraS3ProfileIndex
(
items
,
profileID
)
>=
0
}
func
isEmptyLegacySoraS3Settings
(
settings
*
SoraS3Settings
)
bool
{
if
settings
==
nil
{
return
true
}
if
settings
.
Enabled
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Endpoint
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Region
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Bucket
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
!=
""
{
return
false
}
if
settings
.
SecretAccessKey
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Prefix
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
CDNURL
)
!=
""
{
return
false
}
return
settings
.
DefaultStorageQuotaBytes
==
0
}
func
maxInt64
(
value
int64
,
min
int64
)
int64
{
if
value
<
min
{
return
min
}
return
value
}
backend/internal/service/settings_view.go
View file @
eb2dce92
...
...
@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -116,46 +114,6 @@ type PublicSettings struct {
Version
string
}
// SoraS3Settings Sora S3 存储配置
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"-"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type
SoraS3ProfileList
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type
StreamTimeoutSettings
struct
{
// Enabled 是否启用流超时处理
...
...
backend/internal/service/sora_account_service.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
"context"
// SoraAccountRepository Sora 账号扩展表仓储接口
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
//
// 设计说明:
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
// - Token 刷新时需要同时更新两个表以保持数据一致性
type
SoraAccountRepository
interface
{
// Upsert 创建或更新 Sora 账号扩展信息
// accountID: 关联的 accounts.id
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
//
// 如果记录不存在则创建,存在则更新。
// 用于:
// 1. 创建 Sora 账号时初始化扩展表
// 2. Token 刷新时同步更新扩展表
Upsert
(
ctx
context
.
Context
,
accountID
int64
,
updates
map
[
string
]
any
)
error
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
// 返回 nil, nil 表示记录不存在(非错误)
GetByAccountID
(
ctx
context
.
Context
,
accountID
int64
)
(
*
SoraAccount
,
error
)
// Delete 删除 Sora 账号扩展信息
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
Delete
(
ctx
context
.
Context
,
accountID
int64
)
error
}
// SoraAccount Sora 账号扩展信息
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
type
SoraAccount
struct
{
AccountID
int64
// 关联的 accounts.id
AccessToken
string
// OAuth access_token
RefreshToken
string
// OAuth refresh_token
SessionToken
string
// Session token(可选,用于 ST→AT 兜底)
}
backend/internal/service/sora_client.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"fmt"
"net/http"
)
// SoraClient 定义直连 Sora 的任务操作接口。
type
SoraClient
interface
{
Enabled
()
bool
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
}
// SoraImageRequest 图片生成请求参数
type
SoraImageRequest
struct
{
Prompt
string
Width
int
Height
int
MediaID
string
}
// SoraVideoRequest 视频生成请求参数
type
SoraVideoRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
VideoCount
int
MediaID
string
RemixTargetID
string
CameoIDs
[]
string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type
SoraStoryboardRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
MediaID
string
}
// SoraImageTaskStatus 图片任务状态
type
SoraImageTaskStatus
struct
{
ID
string
Status
string
ProgressPct
float64
URLs
[]
string
ErrorMsg
string
}
// SoraVideoTaskStatus 视频任务状态
type
SoraVideoTaskStatus
struct
{
ID
string
Status
string
ProgressPct
int
URLs
[]
string
GenerationID
string
ErrorMsg
string
}
// SoraCameoStatus 角色处理中间态
type
SoraCameoStatus
struct
{
Status
string
StatusMessage
string
DisplayNameHint
string
UsernameHint
string
ProfileAssetURL
string
InstructionSetHint
any
InstructionSet
any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type
SoraCharacterFinalizeRequest
struct
{
CameoID
string
Username
string
DisplayName
string
ProfileAssetPointer
string
InstructionSet
any
}
// SoraUpstreamError 上游错误
type
SoraUpstreamError
struct
{
StatusCode
int
Message
string
Headers
http
.
Header
Body
[]
byte
}
func
(
e
*
SoraUpstreamError
)
Error
()
string
{
if
e
==
nil
{
return
"sora upstream error"
}
if
e
.
Message
!=
""
{
return
fmt
.
Sprintf
(
"sora upstream error: %d %s"
,
e
.
StatusCode
,
e
.
Message
)
}
return
fmt
.
Sprintf
(
"sora upstream error: %d"
,
e
.
StatusCode
)
}
backend/internal/service/sora_gateway_service.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"math/rand"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
const
soraImageInputMaxBytes
=
20
<<
20
const
soraImageInputMaxRedirects
=
3
const
soraImageInputTimeout
=
20
*
time
.
Second
const
soraVideoInputMaxBytes
=
200
<<
20
const
soraVideoInputMaxRedirects
=
3
const
soraVideoInputTimeout
=
60
*
time
.
Second
var
soraImageSizeMap
=
map
[
string
]
string
{
"gpt-image"
:
"360"
,
"gpt-image-landscape"
:
"540"
,
"gpt-image-portrait"
:
"540"
,
}
var
soraBlockedHostnames
=
map
[
string
]
struct
{}{
"localhost"
:
{},
"localhost.localdomain"
:
{},
"metadata.google.internal"
:
{},
"metadata.google.internal."
:
{},
}
var
soraBlockedCIDRs
=
mustParseCIDRs
([]
string
{
"0.0.0.0/8"
,
"10.0.0.0/8"
,
"100.64.0.0/10"
,
"127.0.0.0/8"
,
"169.254.0.0/16"
,
"172.16.0.0/12"
,
"192.168.0.0/16"
,
"224.0.0.0/4"
,
"240.0.0.0/4"
,
"::/128"
,
"::1/128"
,
"fc00::/7"
,
"fe80::/10"
,
})
// SoraGatewayService handles forwarding requests to Sora upstream.
type
SoraGatewayService
struct
{
soraClient
SoraClient
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
// 用于 apikey 类型账号的 HTTP 透传
cfg
*
config
.
Config
}
type
soraWatermarkOptions
struct
{
Enabled
bool
ParseMethod
string
ParseURL
string
ParseToken
string
FallbackOnFailure
bool
DeletePost
bool
}
type
soraCharacterOptions
struct
{
SetPublic
bool
DeleteAfterGenerate
bool
}
type
soraCharacterFlowResult
struct
{
CameoID
string
CharacterID
string
Username
string
DisplayName
string
}
var
soraStoryboardPattern
=
regexp
.
MustCompile
(
`\[\d+(?:\.\d+)?s\]`
)
var
soraStoryboardShotPattern
=
regexp
.
MustCompile
(
`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`
)
var
soraRemixTargetPattern
=
regexp
.
MustCompile
(
`s_[a-f0-9]{32}`
)
var
soraRemixTargetInURLPattern
=
regexp
.
MustCompile
(
`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`
)
type
soraPreflightChecker
interface
{
PreflightCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
,
modelCfg
SoraModelConfig
)
error
}
func
NewSoraGatewayService
(
soraClient
SoraClient
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
)
*
SoraGatewayService
{
return
&
SoraGatewayService
{
soraClient
:
soraClient
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
clientStream
bool
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if
account
.
Type
==
AccountTypeAPIKey
&&
account
.
GetBaseURL
()
!=
""
{
if
s
.
httpUpstream
==
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"HTTP upstream client not configured"
,
clientStream
)
return
nil
,
errors
.
New
(
"httpUpstream not configured for sora apikey forwarding"
)
}
return
s
.
forwardToUpstream
(
ctx
,
c
,
account
,
body
,
clientStream
,
startTime
)
}
if
s
.
soraClient
==
nil
||
!
s
.
soraClient
.
Enabled
()
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 上游未配置"
,
},
})
}
return
nil
,
errors
.
New
(
"sora upstream not configured"
)
}
var
reqBody
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
reqBody
);
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"parse request: %w"
,
err
)
}
reqModel
,
_
:=
reqBody
[
"model"
]
.
(
string
)
reqStream
,
_
:=
reqBody
[
"stream"
]
.
(
bool
)
if
strings
.
TrimSpace
(
reqModel
)
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"model is required"
)
}
originalModel
:=
reqModel
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
var
upstreamModel
string
if
mappedModel
!=
""
&&
mappedModel
!=
reqModel
{
reqModel
=
mappedModel
upstreamModel
=
mappedModel
}
modelCfg
,
ok
:=
GetSoraModelConfig
(
reqModel
)
if
!
ok
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Unsupported Sora model"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"unsupported model: %s"
,
reqModel
)
}
prompt
,
imageInput
,
videoInput
,
remixTargetID
:=
extractSoraInput
(
reqBody
)
prompt
=
strings
.
TrimSpace
(
prompt
)
imageInput
=
strings
.
TrimSpace
(
imageInput
)
videoInput
=
strings
.
TrimSpace
(
videoInput
)
remixTargetID
=
strings
.
TrimSpace
(
remixTargetID
)
if
videoInput
!=
""
&&
modelCfg
.
Type
!=
"video"
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"video input only supports video models"
,
clientStream
)
return
nil
,
errors
.
New
(
"video input only supports video models"
)
}
if
videoInput
!=
""
&&
imageInput
!=
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"image input and video input cannot be used together"
,
clientStream
)
return
nil
,
errors
.
New
(
"image input and video input cannot be used together"
)
}
characterOnly
:=
videoInput
!=
""
&&
prompt
==
""
if
modelCfg
.
Type
==
"prompt_enhance"
&&
prompt
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
if
modelCfg
.
Type
!=
"prompt_enhance"
&&
prompt
==
""
&&
!
characterOnly
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
reqCtx
,
cancel
:=
s
.
withSoraTimeout
(
ctx
,
reqStream
)
if
cancel
!=
nil
{
defer
cancel
()
}
if
checker
,
ok
:=
s
.
soraClient
.
(
soraPreflightChecker
);
ok
&&
!
characterOnly
{
if
err
:=
checker
.
PreflightCheck
(
reqCtx
,
account
,
reqModel
,
modelCfg
);
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
}
if
modelCfg
.
Type
==
"prompt_enhance"
{
enhancedPrompt
,
err
:=
s
.
soraClient
.
EnhancePrompt
(
reqCtx
,
account
,
prompt
,
modelCfg
.
ExpansionLevel
,
modelCfg
.
DurationS
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
content
:=
strings
.
TrimSpace
(
enhancedPrompt
)
if
content
==
""
{
content
=
prompt
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusOK
,
buildSoraNonStreamResponse
(
content
,
reqModel
))
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
characterOpts
:=
parseSoraCharacterOptions
(
reqBody
)
watermarkOpts
:=
parseSoraWatermarkOptions
(
reqBody
)
var
characterResult
*
soraCharacterFlowResult
if
videoInput
!=
""
{
videoData
,
videoErr
:=
decodeSoraVideoInput
(
reqCtx
,
videoInput
)
if
videoErr
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
videoErr
.
Error
(),
clientStream
)
return
nil
,
videoErr
}
characterResult
,
videoErr
=
s
.
createCharacterFromVideo
(
reqCtx
,
account
,
videoData
,
characterOpts
)
if
videoErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
videoErr
,
reqModel
,
c
,
clientStream
)
}
if
characterResult
!=
nil
&&
characterOpts
.
DeleteAfterGenerate
&&
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
!=
""
&&
!
characterOnly
{
characterID
:=
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
defer
func
()
{
cleanupCtx
,
cancelCleanup
:=
context
.
WithTimeout
(
context
.
Background
(),
15
*
time
.
Second
)
defer
cancelCleanup
()
if
err
:=
s
.
soraClient
.
DeleteCharacter
(
cleanupCtx
,
account
,
characterID
);
err
!=
nil
{
log
.
Printf
(
"[Sora] cleanup character failed, character_id=%s err=%v"
,
characterID
,
err
)
}
}()
}
if
characterOnly
{
content
:=
"角色创建成功"
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
content
=
fmt
.
Sprintf
(
"角色创建成功,角色名@%s"
,
strings
.
TrimSpace
(
characterResult
.
Username
))
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
resp
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
characterResult
!=
nil
{
resp
[
"character_id"
]
=
characterResult
.
CharacterID
resp
[
"cameo_id"
]
=
characterResult
.
CameoID
resp
[
"character_username"
]
=
characterResult
.
Username
resp
[
"character_display_name"
]
=
characterResult
.
DisplayName
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
prompt
=
fmt
.
Sprintf
(
"@%s %s"
,
characterResult
.
Username
,
prompt
)
}
}
var
imageData
[]
byte
imageFilename
:=
""
if
imageInput
!=
""
{
decoded
,
filename
,
err
:=
decodeSoraImageInput
(
reqCtx
,
imageInput
)
if
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
(),
clientStream
)
return
nil
,
err
}
imageData
=
decoded
imageFilename
=
filename
}
mediaID
:=
""
if
len
(
imageData
)
>
0
{
uploadID
,
err
:=
s
.
soraClient
.
UploadImage
(
reqCtx
,
account
,
imageData
,
imageFilename
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
mediaID
=
uploadID
}
taskID
:=
""
var
err
error
videoCount
:=
parseSoraVideoCount
(
reqBody
)
switch
modelCfg
.
Type
{
case
"image"
:
taskID
,
err
=
s
.
soraClient
.
CreateImageTask
(
reqCtx
,
account
,
SoraImageRequest
{
Prompt
:
prompt
,
Width
:
modelCfg
.
Width
,
Height
:
modelCfg
.
Height
,
MediaID
:
mediaID
,
})
case
"video"
:
if
remixTargetID
==
""
&&
isSoraStoryboardPrompt
(
prompt
)
{
taskID
,
err
=
s
.
soraClient
.
CreateStoryboardTask
(
reqCtx
,
account
,
SoraStoryboardRequest
{
Prompt
:
formatSoraStoryboardPrompt
(
prompt
),
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
MediaID
:
mediaID
,
})
}
else
{
taskID
,
err
=
s
.
soraClient
.
CreateVideoTask
(
reqCtx
,
account
,
SoraVideoRequest
{
Prompt
:
prompt
,
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
VideoCount
:
videoCount
,
MediaID
:
mediaID
,
RemixTargetID
:
remixTargetID
,
CameoIDs
:
extractSoraCameoIDs
(
reqBody
),
})
}
default
:
err
=
fmt
.
Errorf
(
"unsupported model type: %s"
,
modelCfg
.
Type
)
}
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
if
clientStream
&&
c
!=
nil
{
s
.
prepareSoraStream
(
c
,
taskID
)
}
var
mediaURLs
[]
string
videoGenerationID
:=
""
mediaType
:=
modelCfg
.
Type
imageCount
:=
0
imageSize
:=
""
switch
modelCfg
.
Type
{
case
"image"
:
urls
,
pollErr
:=
s
.
pollImageTask
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
mediaURLs
=
urls
imageCount
=
len
(
urls
)
imageSize
=
soraImageSizeFromModel
(
reqModel
)
case
"video"
:
videoStatus
,
pollErr
:=
s
.
pollVideoTaskDetailed
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
if
videoStatus
!=
nil
{
mediaURLs
=
videoStatus
.
URLs
videoGenerationID
=
strings
.
TrimSpace
(
videoStatus
.
GenerationID
)
}
default
:
mediaType
=
"prompt"
}
watermarkPostID
:=
""
if
modelCfg
.
Type
==
"video"
&&
watermarkOpts
.
Enabled
{
watermarkURL
,
postID
,
watermarkErr
:=
s
.
resolveWatermarkFreeURL
(
reqCtx
,
account
,
videoGenerationID
,
watermarkOpts
)
if
watermarkErr
!=
nil
{
if
!
watermarkOpts
.
FallbackOnFailure
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
watermarkErr
,
reqModel
,
c
,
clientStream
)
}
log
.
Printf
(
"[Sora] watermark-free fallback to original URL, task_id=%s err=%v"
,
taskID
,
watermarkErr
)
}
else
if
strings
.
TrimSpace
(
watermarkURL
)
!=
""
{
mediaURLs
=
[]
string
{
strings
.
TrimSpace
(
watermarkURL
)}
watermarkPostID
=
strings
.
TrimSpace
(
postID
)
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs
:=
s
.
normalizeSoraMediaURLs
(
mediaURLs
)
if
watermarkPostID
!=
""
&&
watermarkOpts
.
DeletePost
{
if
deleteErr
:=
s
.
soraClient
.
DeletePost
(
reqCtx
,
account
,
watermarkPostID
);
deleteErr
!=
nil
{
log
.
Printf
(
"[Sora] delete post failed, post_id=%s err=%v"
,
watermarkPostID
,
deleteErr
)
}
}
content
:=
buildSoraContent
(
mediaType
,
finalURLs
)
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
response
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
len
(
finalURLs
)
>
0
{
response
[
"media_url"
]
=
finalURLs
[
0
]
if
len
(
finalURLs
)
>
1
{
response
[
"media_urls"
]
=
finalURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
ForwardResult
{
RequestID
:
taskID
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
mediaType
,
MediaURL
:
firstMediaURL
(
finalURLs
),
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
func
(
s
*
SoraGatewayService
)
withSoraTimeout
(
ctx
context
.
Context
,
stream
bool
)
(
context
.
Context
,
context
.
CancelFunc
)
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
ctx
,
nil
}
timeoutSeconds
:=
s
.
cfg
.
Gateway
.
SoraRequestTimeoutSeconds
if
stream
{
timeoutSeconds
=
s
.
cfg
.
Gateway
.
SoraStreamTimeoutSeconds
}
if
timeoutSeconds
<=
0
{
return
ctx
,
nil
}
return
context
.
WithTimeout
(
ctx
,
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
)
}
func
parseSoraWatermarkOptions
(
body
map
[
string
]
any
)
soraWatermarkOptions
{
opts
:=
soraWatermarkOptions
{
Enabled
:
parseBoolWithDefault
(
body
,
"watermark_free"
,
false
),
ParseMethod
:
strings
.
ToLower
(
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_method"
,
"third_party"
))),
ParseURL
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_url"
,
""
)),
ParseToken
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_token"
,
""
)),
FallbackOnFailure
:
parseBoolWithDefault
(
body
,
"watermark_fallback_on_failure"
,
true
),
DeletePost
:
parseBoolWithDefault
(
body
,
"watermark_delete_post"
,
false
),
}
if
opts
.
ParseMethod
==
""
{
opts
.
ParseMethod
=
"third_party"
}
return
opts
}
func
parseSoraCharacterOptions
(
body
map
[
string
]
any
)
soraCharacterOptions
{
return
soraCharacterOptions
{
SetPublic
:
parseBoolWithDefault
(
body
,
"character_set_public"
,
true
),
DeleteAfterGenerate
:
parseBoolWithDefault
(
body
,
"character_delete_after_generate"
,
true
),
}
}
func
parseSoraVideoCount
(
body
map
[
string
]
any
)
int
{
if
body
==
nil
{
return
1
}
keys
:=
[]
string
{
"video_count"
,
"videos"
,
"n_variants"
}
for
_
,
key
:=
range
keys
{
count
:=
parseIntWithDefault
(
body
,
key
,
0
)
if
count
>
0
{
return
clampInt
(
count
,
1
,
3
)
}
}
return
1
}
func
parseBoolWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
bool
)
bool
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
bool
:
return
typed
case
int
:
return
typed
!=
0
case
int32
:
return
typed
!=
0
case
int64
:
return
typed
!=
0
case
float64
:
return
typed
!=
0
case
string
:
typed
=
strings
.
ToLower
(
strings
.
TrimSpace
(
typed
))
if
typed
==
"true"
||
typed
==
"1"
||
typed
==
"yes"
{
return
true
}
if
typed
==
"false"
||
typed
==
"0"
||
typed
==
"no"
{
return
false
}
}
return
def
}
func
parseStringWithDefault
(
body
map
[
string
]
any
,
key
,
def
string
)
string
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
if
str
,
ok
:=
val
.
(
string
);
ok
{
return
str
}
return
def
}
func
parseIntWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
int
)
int
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
int
:
return
typed
case
int32
:
return
int
(
typed
)
case
int64
:
return
int
(
typed
)
case
float64
:
return
int
(
typed
)
case
string
:
parsed
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
typed
))
if
err
==
nil
{
return
parsed
}
}
return
def
}
func
clampInt
(
v
,
minVal
,
maxVal
int
)
int
{
if
v
<
minVal
{
return
minVal
}
if
v
>
maxVal
{
return
maxVal
}
return
v
}
func
extractSoraCameoIDs
(
body
map
[
string
]
any
)
[]
string
{
if
body
==
nil
{
return
nil
}
raw
,
ok
:=
body
[
"cameo_ids"
]
if
!
ok
{
return
nil
}
switch
typed
:=
raw
.
(
type
)
{
case
[]
string
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
item
=
strings
.
TrimSpace
(
item
)
if
item
!=
""
{
out
=
append
(
out
,
item
)
}
}
return
out
case
[]
any
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
str
,
ok
:=
item
.
(
string
)
if
!
ok
{
continue
}
str
=
strings
.
TrimSpace
(
str
)
if
str
!=
""
{
out
=
append
(
out
,
str
)
}
}
return
out
default
:
return
nil
}
}
func
(
s
*
SoraGatewayService
)
createCharacterFromVideo
(
ctx
context
.
Context
,
account
*
Account
,
videoData
[]
byte
,
opts
soraCharacterOptions
)
(
*
soraCharacterFlowResult
,
error
)
{
cameoID
,
err
:=
s
.
soraClient
.
UploadCharacterVideo
(
ctx
,
account
,
videoData
)
if
err
!=
nil
{
return
nil
,
err
}
cameoStatus
,
err
:=
s
.
pollCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
return
nil
,
err
}
username
:=
processSoraCharacterUsername
(
cameoStatus
.
UsernameHint
)
displayName
:=
strings
.
TrimSpace
(
cameoStatus
.
DisplayNameHint
)
if
displayName
==
""
{
displayName
=
"Character"
}
profileAssetURL
:=
strings
.
TrimSpace
(
cameoStatus
.
ProfileAssetURL
)
if
profileAssetURL
==
""
{
return
nil
,
errors
.
New
(
"profile asset url not found in cameo status"
)
}
avatarData
,
err
:=
s
.
soraClient
.
DownloadCharacterImage
(
ctx
,
account
,
profileAssetURL
)
if
err
!=
nil
{
return
nil
,
err
}
assetPointer
,
err
:=
s
.
soraClient
.
UploadCharacterImage
(
ctx
,
account
,
avatarData
)
if
err
!=
nil
{
return
nil
,
err
}
instructionSet
:=
cameoStatus
.
InstructionSetHint
if
instructionSet
==
nil
{
instructionSet
=
cameoStatus
.
InstructionSet
}
characterID
,
err
:=
s
.
soraClient
.
FinalizeCharacter
(
ctx
,
account
,
SoraCharacterFinalizeRequest
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
Username
:
username
,
DisplayName
:
displayName
,
ProfileAssetPointer
:
assetPointer
,
InstructionSet
:
instructionSet
,
})
if
err
!=
nil
{
return
nil
,
err
}
if
opts
.
SetPublic
{
if
err
:=
s
.
soraClient
.
SetCharacterPublic
(
ctx
,
account
,
cameoID
);
err
!=
nil
{
return
nil
,
err
}
}
return
&
soraCharacterFlowResult
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
CharacterID
:
strings
.
TrimSpace
(
characterID
),
Username
:
strings
.
TrimSpace
(
username
),
DisplayName
:
displayName
,
},
nil
}
func
(
s
*
SoraGatewayService
)
pollCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
timeout
:=
10
*
time
.
Minute
interval
:=
5
*
time
.
Second
maxAttempts
:=
int
(
math
.
Ceil
(
timeout
.
Seconds
()
/
interval
.
Seconds
()))
if
maxAttempts
<
1
{
maxAttempts
=
1
}
var
lastErr
error
consecutiveErrors
:=
0
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
lastErr
=
err
consecutiveErrors
++
if
consecutiveErrors
>=
3
{
break
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
consecutiveErrors
=
0
if
status
==
nil
{
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
currentStatus
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
status
.
Status
))
statusMessage
:=
strings
.
TrimSpace
(
status
.
StatusMessage
)
if
currentStatus
==
"failed"
{
if
statusMessage
==
""
{
statusMessage
=
"character creation failed"
}
return
nil
,
errors
.
New
(
statusMessage
)
}
if
strings
.
EqualFold
(
statusMessage
,
"Completed"
)
||
currentStatus
==
"finalized"
{
return
status
,
nil
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
}
if
lastErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"poll cameo status failed: %w"
,
lastErr
)
}
return
nil
,
errors
.
New
(
"cameo processing timeout"
)
}
func
processSoraCharacterUsername
(
usernameHint
string
)
string
{
usernameHint
=
strings
.
TrimSpace
(
usernameHint
)
if
usernameHint
==
""
{
usernameHint
=
"character"
}
if
strings
.
Contains
(
usernameHint
,
"."
)
{
parts
:=
strings
.
Split
(
usernameHint
,
"."
)
usernameHint
=
strings
.
TrimSpace
(
parts
[
len
(
parts
)
-
1
])
}
if
usernameHint
==
""
{
usernameHint
=
"character"
}
return
fmt
.
Sprintf
(
"%s%d"
,
usernameHint
,
rand
.
Intn
(
900
)
+
100
)
}
func
(
s
*
SoraGatewayService
)
resolveWatermarkFreeURL
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
,
opts
soraWatermarkOptions
)
(
string
,
string
,
error
)
{
generationID
=
strings
.
TrimSpace
(
generationID
)
if
generationID
==
""
{
return
""
,
""
,
errors
.
New
(
"generation id is required for watermark-free mode"
)
}
postID
,
err
:=
s
.
soraClient
.
PostVideoForWatermarkFree
(
ctx
,
account
,
generationID
)
if
err
!=
nil
{
return
""
,
""
,
err
}
postID
=
strings
.
TrimSpace
(
postID
)
if
postID
==
""
{
return
""
,
""
,
errors
.
New
(
"watermark-free publish returned empty post id"
)
}
switch
opts
.
ParseMethod
{
case
"custom"
:
urlVal
,
parseErr
:=
s
.
soraClient
.
GetWatermarkFreeURLCustom
(
ctx
,
account
,
opts
.
ParseURL
,
opts
.
ParseToken
,
postID
)
if
parseErr
!=
nil
{
return
""
,
postID
,
parseErr
}
return
strings
.
TrimSpace
(
urlVal
),
postID
,
nil
case
""
,
"third_party"
:
return
fmt
.
Sprintf
(
"https://oscdn2.dyysy.com/MP4/%s.mp4"
,
postID
),
postID
,
nil
default
:
return
""
,
postID
,
fmt
.
Errorf
(
"unsupported watermark parse method: %s"
,
opts
.
ParseMethod
)
}
}
func
(
s
*
SoraGatewayService
)
shouldFailoverUpstreamError
(
statusCode
int
)
bool
{
switch
statusCode
{
case
401
,
402
,
403
,
404
,
429
,
529
:
return
true
default
:
return
statusCode
>=
500
}
}
func
buildSoraNonStreamResponse
(
content
,
model
string
)
map
[
string
]
any
{
return
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"message"
:
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
content
,
},
"finish_reason"
:
"stop"
,
},
},
}
}
func
soraImageSizeFromModel
(
model
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
size
,
ok
:=
soraImageSizeMap
[
modelLower
];
ok
{
return
size
}
if
strings
.
Contains
(
modelLower
,
"landscape"
)
||
strings
.
Contains
(
modelLower
,
"portrait"
)
{
return
"540"
}
return
"360"
}
func
soraProErrorMessage
(
model
,
upstreamMsg
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
strings
.
Contains
(
modelLower
,
"sora2pro-hd"
)
{
return
"当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
}
if
strings
.
Contains
(
modelLower
,
"sora2pro"
)
{
return
"当前账号无法使用 Sora Pro 模型,请更换模型或账号"
}
return
""
}
func
firstMediaURL
(
urls
[]
string
)
string
{
if
len
(
urls
)
==
0
{
return
""
}
return
urls
[
0
]
}
func
(
s
*
SoraGatewayService
)
buildSoraMediaURL
(
path
string
,
rawQuery
string
)
string
{
if
path
==
""
{
return
path
}
prefix
:=
"/sora/media"
values
:=
url
.
Values
{}
if
rawQuery
!=
""
{
if
parsed
,
err
:=
url
.
ParseQuery
(
rawQuery
);
err
==
nil
{
values
=
parsed
}
}
signKey
:=
""
ttlSeconds
:=
0
if
s
!=
nil
&&
s
.
cfg
!=
nil
{
signKey
=
strings
.
TrimSpace
(
s
.
cfg
.
Gateway
.
SoraMediaSigningKey
)
ttlSeconds
=
s
.
cfg
.
Gateway
.
SoraMediaSignedURLTTLSeconds
}
values
.
Del
(
"sig"
)
values
.
Del
(
"expires"
)
signingQuery
:=
values
.
Encode
()
if
signKey
!=
""
&&
ttlSeconds
>
0
{
expires
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
ttlSeconds
)
*
time
.
Second
)
.
Unix
()
signature
:=
SignSoraMediaURL
(
path
,
signingQuery
,
expires
,
signKey
)
if
signature
!=
""
{
values
.
Set
(
"expires"
,
strconv
.
FormatInt
(
expires
,
10
))
values
.
Set
(
"sig"
,
signature
)
prefix
=
"/sora/media-signed"
}
}
encoded
:=
values
.
Encode
()
if
encoded
==
""
{
return
prefix
+
path
}
return
prefix
+
path
+
"?"
+
encoded
}
func
(
s
*
SoraGatewayService
)
prepareSoraStream
(
c
*
gin
.
Context
,
requestID
string
)
{
if
c
==
nil
{
return
}
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
strings
.
TrimSpace
(
requestID
)
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
}
func
(
s
*
SoraGatewayService
)
writeSoraStream
(
c
*
gin
.
Context
,
model
,
content
string
,
startTime
time
.
Time
)
(
*
int
,
error
)
{
if
c
==
nil
{
return
nil
,
nil
}
writer
:=
c
.
Writer
flusher
,
_
:=
writer
.
(
http
.
Flusher
)
chunk
:=
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{
"content"
:
content
,
},
},
},
}
encoded
,
_
:=
jsonMarshalRaw
(
chunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
encoded
);
err
!=
nil
{
return
nil
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
finalChunk
:=
map
[
string
]
any
{
"id"
:
chunk
[
"id"
],
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{},
"finish_reason"
:
"stop"
,
},
},
}
finalEncoded
,
_
:=
jsonMarshalRaw
(
finalChunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
finalEncoded
);
err
!=
nil
{
return
&
ms
,
err
}
if
_
,
err
:=
fmt
.
Fprint
(
writer
,
"data: [DONE]
\n\n
"
);
err
!=
nil
{
return
&
ms
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
&
ms
,
nil
}
func
(
s
*
SoraGatewayService
)
writeSoraError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
stream
bool
)
{
if
c
==
nil
{
return
}
if
stream
{
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
)
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
"data: [DONE]
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
}
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
,
},
})
}
func
(
s
*
SoraGatewayService
)
handleSoraRequestError
(
ctx
context
.
Context
,
account
*
Account
,
err
error
,
model
string
,
c
*
gin
.
Context
,
stream
bool
)
error
{
if
err
==
nil
{
return
nil
}
var
upstreamErr
*
SoraUpstreamError
if
errors
.
As
(
err
,
&
upstreamErr
)
{
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
logger
.
LegacyPrintf
(
"service.sora"
,
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s"
,
accountID
,
model
,
upstreamErr
.
StatusCode
,
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"x-request-id"
)),
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"cf-ray"
)),
strings
.
TrimSpace
(
upstreamErr
.
Message
),
truncateForLog
(
upstreamErr
.
Body
,
1024
),
)
if
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
upstreamErr
.
StatusCode
,
upstreamErr
.
Headers
,
upstreamErr
.
Body
)
}
if
s
.
shouldFailoverUpstreamError
(
upstreamErr
.
StatusCode
)
{
var
responseHeaders
http
.
Header
if
upstreamErr
.
Headers
!=
nil
{
responseHeaders
=
upstreamErr
.
Headers
.
Clone
()
}
return
&
UpstreamFailoverError
{
StatusCode
:
upstreamErr
.
StatusCode
,
ResponseBody
:
upstreamErr
.
Body
,
ResponseHeaders
:
responseHeaders
,
}
}
msg
:=
upstreamErr
.
Message
if
override
:=
soraProErrorMessage
(
model
,
msg
);
override
!=
""
{
msg
=
override
}
s
.
writeSoraError
(
c
,
upstreamErr
.
StatusCode
,
"upstream_error"
,
msg
,
stream
)
return
err
}
if
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
s
.
writeSoraError
(
c
,
http
.
StatusGatewayTimeout
,
"timeout_error"
,
"Sora generation timeout"
,
stream
)
return
err
}
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"api_error"
,
err
.
Error
(),
stream
)
return
err
}
func
(
s
*
SoraGatewayService
)
pollImageTask
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
([]
string
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetImageTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"succeeded"
,
"completed"
:
return
status
.
URLs
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora image generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora image generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollVideoTaskDetailed
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
(
*
SoraVideoTaskStatus
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetVideoTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"completed"
,
"succeeded"
:
return
status
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora video generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora video generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollInterval
()
time
.
Duration
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
2
*
time
.
Second
}
interval
:=
s
.
cfg
.
Sora
.
Client
.
PollIntervalSeconds
if
interval
<=
0
{
interval
=
2
}
return
time
.
Duration
(
interval
)
*
time
.
Second
}
func
(
s
*
SoraGatewayService
)
pollMaxAttempts
()
int
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
600
}
maxAttempts
:=
s
.
cfg
.
Sora
.
Client
.
MaxPollAttempts
if
maxAttempts
<=
0
{
maxAttempts
=
600
}
return
maxAttempts
}
func
(
s
*
SoraGatewayService
)
maybeSendPing
(
c
*
gin
.
Context
,
lastPing
*
time
.
Time
)
{
if
c
==
nil
{
return
}
interval
:=
10
*
time
.
Second
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Concurrency
.
PingInterval
>
0
{
interval
=
time
.
Duration
(
s
.
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
}
if
time
.
Since
(
*
lastPing
)
<
interval
{
return
}
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
":
\n\n
"
);
err
==
nil
{
if
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
);
ok
{
flusher
.
Flush
()
}
*
lastPing
=
time
.
Now
()
}
}
func
(
s
*
SoraGatewayService
)
normalizeSoraMediaURLs
(
urls
[]
string
)
[]
string
{
if
len
(
urls
)
==
0
{
return
urls
}
output
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
continue
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
output
=
append
(
output
,
raw
)
continue
}
pathVal
:=
raw
if
!
strings
.
HasPrefix
(
pathVal
,
"/"
)
{
pathVal
=
"/"
+
pathVal
}
output
=
append
(
output
,
s
.
buildSoraMediaURL
(
pathVal
,
""
))
}
return
output
}
// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
func
jsonMarshalRaw
(
v
any
)
([]
byte
,
error
)
{
var
buf
bytes
.
Buffer
enc
:=
json
.
NewEncoder
(
&
buf
)
enc
.
SetEscapeHTML
(
false
)
if
err
:=
enc
.
Encode
(
v
);
err
!=
nil
{
return
nil
,
err
}
// Encode 会追加换行符,去掉它
b
:=
buf
.
Bytes
()
if
len
(
b
)
>
0
&&
b
[
len
(
b
)
-
1
]
==
'\n'
{
b
=
b
[
:
len
(
b
)
-
1
]
}
return
b
,
nil
}
func
buildSoraContent
(
mediaType
string
,
urls
[]
string
)
string
{
switch
mediaType
{
case
"image"
:
parts
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
u
:=
range
urls
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
""
,
u
))
}
return
strings
.
Join
(
parts
,
"
\n
"
)
case
"video"
:
if
len
(
urls
)
==
0
{
return
""
}
return
fmt
.
Sprintf
(
"```html
\n
<video src='%s' controls></video>
\n
```"
,
urls
[
0
])
default
:
return
""
}
}
func
extractSoraInput
(
body
map
[
string
]
any
)
(
prompt
,
imageInput
,
videoInput
,
remixTargetID
string
)
{
if
body
==
nil
{
return
""
,
""
,
""
,
""
}
if
v
,
ok
:=
body
[
"remix_target_id"
]
.
(
string
);
ok
{
remixTargetID
=
strings
.
TrimSpace
(
v
)
}
if
v
,
ok
:=
body
[
"image"
]
.
(
string
);
ok
{
imageInput
=
v
}
if
v
,
ok
:=
body
[
"video"
]
.
(
string
);
ok
{
videoInput
=
v
}
if
v
,
ok
:=
body
[
"prompt"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
prompt
=
v
}
if
messages
,
ok
:=
body
[
"messages"
]
.
([]
any
);
ok
{
builder
:=
strings
.
Builder
{}
for
_
,
raw
:=
range
messages
{
msg
,
ok
:=
raw
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
role
,
_
:=
msg
[
"role"
]
.
(
string
)
if
role
!=
""
&&
role
!=
"user"
{
continue
}
content
:=
msg
[
"content"
]
text
,
img
,
vid
:=
parseSoraMessageContent
(
content
)
if
text
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
text
)
}
if
imageInput
==
""
&&
img
!=
""
{
imageInput
=
img
}
if
videoInput
==
""
&&
vid
!=
""
{
videoInput
=
vid
}
}
if
prompt
==
""
{
prompt
=
builder
.
String
()
}
}
if
remixTargetID
==
""
{
remixTargetID
=
extractRemixTargetIDFromPrompt
(
prompt
)
}
prompt
=
cleanRemixLinkFromPrompt
(
prompt
)
return
prompt
,
imageInput
,
videoInput
,
remixTargetID
}
func
parseSoraMessageContent
(
content
any
)
(
text
,
imageInput
,
videoInput
string
)
{
switch
val
:=
content
.
(
type
)
{
case
string
:
return
val
,
""
,
""
case
[]
any
:
builder
:=
strings
.
Builder
{}
for
_
,
item
:=
range
val
{
itemMap
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
t
,
_
:=
itemMap
[
"type"
]
.
(
string
)
switch
t
{
case
"text"
:
if
txt
,
ok
:=
itemMap
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
txt
)
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
txt
)
}
case
"image_url"
:
if
imageInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"image_url"
]
.
(
map
[
string
]
any
);
ok
{
imageInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"image_url"
]
.
(
string
);
ok
{
imageInput
=
urlStr
}
}
case
"video_url"
:
if
videoInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"video_url"
]
.
(
map
[
string
]
any
);
ok
{
videoInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"video_url"
]
.
(
string
);
ok
{
videoInput
=
urlStr
}
}
}
}
return
builder
.
String
(),
imageInput
,
videoInput
default
:
return
""
,
""
,
""
}
}
func
isSoraStoryboardPrompt
(
prompt
string
)
bool
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
false
}
return
len
(
soraStoryboardPattern
.
FindAllString
(
prompt
,
-
1
))
>=
1
}
func
formatSoraStoryboardPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
matches
:=
soraStoryboardShotPattern
.
FindAllStringSubmatch
(
prompt
,
-
1
)
if
len
(
matches
)
==
0
{
return
prompt
}
firstBracketPos
:=
strings
.
Index
(
prompt
,
"["
)
instructions
:=
""
if
firstBracketPos
>
0
{
instructions
=
strings
.
TrimSpace
(
prompt
[
:
firstBracketPos
])
}
shots
:=
make
([]
string
,
0
,
len
(
matches
))
for
i
,
match
:=
range
matches
{
if
len
(
match
)
<
3
{
continue
}
duration
:=
strings
.
TrimSpace
(
match
[
1
])
scene
:=
strings
.
TrimSpace
(
match
[
2
])
if
scene
==
""
{
continue
}
shots
=
append
(
shots
,
fmt
.
Sprintf
(
"Shot %d:
\n
duration: %ssec
\n
Scene: %s"
,
i
+
1
,
duration
,
scene
))
}
if
len
(
shots
)
==
0
{
return
prompt
}
timeline
:=
strings
.
Join
(
shots
,
"
\n\n
"
)
if
instructions
==
""
{
return
timeline
}
return
fmt
.
Sprintf
(
"current timeline:
\n
%s
\n\n
instructions:
\n
%s"
,
timeline
,
instructions
)
}
func
extractRemixTargetIDFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
return
strings
.
TrimSpace
(
soraRemixTargetPattern
.
FindString
(
prompt
))
}
func
cleanRemixLinkFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
prompt
}
cleaned
:=
soraRemixTargetInURLPattern
.
ReplaceAllString
(
prompt
,
""
)
cleaned
=
soraRemixTargetPattern
.
ReplaceAllString
(
cleaned
,
""
)
cleaned
=
strings
.
Join
(
strings
.
Fields
(
cleaned
),
" "
)
return
strings
.
TrimSpace
(
cleaned
)
}
func
decodeSoraImageInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
string
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
""
,
errors
.
New
(
"empty image input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
""
,
errors
.
New
(
"invalid data url"
)
}
meta
:=
parts
[
0
]
payload
:=
parts
[
1
]
decoded
,
err
:=
decodeBase64WithLimit
(
payload
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
""
if
strings
.
HasPrefix
(
meta
,
"data:"
)
{
metaParts
:=
strings
.
SplitN
(
meta
[
5
:
],
";"
,
2
)
if
len
(
metaParts
)
>
0
{
if
exts
,
err
:=
mime
.
ExtensionsByType
(
metaParts
[
0
]);
err
==
nil
&&
len
(
exts
)
>
0
{
ext
=
exts
[
0
]
}
}
}
filename
:=
"image"
+
ext
return
decoded
,
filename
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraImageInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
errors
.
New
(
"invalid base64 image"
)
}
return
decoded
,
"image.png"
,
nil
}
func
decodeSoraVideoInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
errors
.
New
(
"empty video input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
errors
.
New
(
"invalid video data url"
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
parts
[
1
],
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraVideoInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
func
downloadSoraImageInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
string
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraImageInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraImageInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
""
,
fmt
.
Errorf
(
"download image failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraImageInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
fileExtFromURL
(
parsed
.
String
())
if
ext
==
""
{
ext
=
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
))
}
filename
:=
"image"
+
ext
return
data
,
filename
,
nil
}
func
downloadSoraVideoInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraVideoInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraVideoInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"download video failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraVideoInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
data
)
==
0
{
return
nil
,
errors
.
New
(
"empty video content"
)
}
return
data
,
nil
}
func
decodeBase64WithLimit
(
encoded
string
,
maxBytes
int64
)
([]
byte
,
error
)
{
if
maxBytes
<=
0
{
return
nil
,
errors
.
New
(
"invalid max bytes limit"
)
}
decoder
:=
base64
.
NewDecoder
(
base64
.
StdEncoding
,
strings
.
NewReader
(
encoded
))
limited
:=
io
.
LimitReader
(
decoder
,
maxBytes
+
1
)
data
,
err
:=
io
.
ReadAll
(
limited
)
if
err
!=
nil
{
return
nil
,
err
}
if
int64
(
len
(
data
))
>
maxBytes
{
return
nil
,
fmt
.
Errorf
(
"input exceeds %d bytes limit"
,
maxBytes
)
}
return
data
,
nil
}
func
validateSoraRemoteURL
(
raw
string
)
(
*
url
.
URL
,
error
)
{
if
strings
.
TrimSpace
(
raw
)
==
""
{
return
nil
,
errors
.
New
(
"empty remote url"
)
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid remote url: %w"
,
err
)
}
if
err
:=
validateSoraRemoteURLValue
(
parsed
);
err
!=
nil
{
return
nil
,
err
}
return
parsed
,
nil
}
func
validateSoraRemoteURLValue
(
parsed
*
url
.
URL
)
error
{
if
parsed
==
nil
{
return
errors
.
New
(
"invalid remote url"
)
}
scheme
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Scheme
))
if
scheme
!=
"http"
&&
scheme
!=
"https"
{
return
errors
.
New
(
"only http/https remote url is allowed"
)
}
if
parsed
.
User
!=
nil
{
return
errors
.
New
(
"remote url cannot contain userinfo"
)
}
host
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Hostname
()))
if
host
==
""
{
return
errors
.
New
(
"remote url missing host"
)
}
if
_
,
blocked
:=
soraBlockedHostnames
[
host
];
blocked
{
return
errors
.
New
(
"remote url is not allowed"
)
}
if
ip
:=
net
.
ParseIP
(
host
);
ip
!=
nil
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
return
nil
}
ips
,
err
:=
net
.
LookupIP
(
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"resolve remote url failed: %w"
,
err
)
}
for
_
,
ip
:=
range
ips
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
}
return
nil
}
func
isSoraBlockedIP
(
ip
net
.
IP
)
bool
{
if
ip
==
nil
{
return
true
}
for
_
,
cidr
:=
range
soraBlockedCIDRs
{
if
cidr
.
Contains
(
ip
)
{
return
true
}
}
return
false
}
func
mustParseCIDRs
(
values
[]
string
)
[]
*
net
.
IPNet
{
out
:=
make
([]
*
net
.
IPNet
,
0
,
len
(
values
))
for
_
,
val
:=
range
values
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
val
)
if
err
!=
nil
{
continue
}
out
=
append
(
out
,
cidr
)
}
return
out
}
backend/internal/service/sora_gateway_service_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
service
import
(
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var
_
SoraClient
=
(
*
stubSoraClientForPoll
)(
nil
)
type
stubSoraClientForPoll
struct
{
imageStatus
*
SoraImageTaskStatus
videoStatus
*
SoraVideoTaskStatus
imageCalls
int
videoCalls
int
enhanced
string
enhanceErr
error
storyboard
bool
videoReq
SoraVideoRequest
parseErr
error
postCalls
int
deleteCalls
int
}
func
(
s
*
stubSoraClientForPoll
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClientForPoll
)
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
{
s
.
videoReq
=
req
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
{
s
.
storyboard
=
true
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
return
&
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClientForPoll
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
{
s
.
postCalls
++
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
{
s
.
deleteCalls
++
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
if
s
.
parseErr
!=
nil
{
return
""
,
s
.
parseErr
}
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
if
s
.
enhanced
!=
""
{
return
s
.
enhanced
,
s
.
enhanceErr
}
return
"enhanced prompt"
,
s
.
enhanceErr
}
func
(
s
*
stubSoraClientForPoll
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
{
s
.
imageCalls
++
return
s
.
imageStatus
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
{
s
.
videoCalls
++
return
s
.
videoStatus
,
nil
}
func
TestSoraGatewayService_PollImageTaskCompleted
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
imageStatus
:
&
SoraImageTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/a.png"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
urls
,
err
:=
service
.
pollImageTask
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"https://example.com/a.png"
},
urls
)
require
.
Equal
(
t
,
1
,
client
.
imageCalls
)
}
func
TestSoraGatewayService_ForwardPromptEnhance
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
enhanced
:
"cinematic prompt"
,
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"prompt-enhance-short-10s"
:
"prompt-enhance-short-15s"
,
},
},
}
body
:=
[]
byte
(
`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
"prompt-enhance-short-10s"
,
result
.
Model
)
require
.
Equal
(
t
,
"prompt-enhance-short-15s"
,
result
.
UpstreamModel
)
}
func
TestSoraGatewayService_ForwardStoryboardPrompt
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
client
.
storyboard
)
}
func
TestSoraGatewayService_ForwardVideoCount
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
3
,
client
.
videoReq
.
VideoCount
)
}
func
TestSoraGatewayService_ForwardCharacterOnly
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
0
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkFallback
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
parseErr
:
errors
.
New
(
"parse failed"
),
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/original.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
0
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/no-watermark.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
1
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_PollVideoTaskFailed
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"reject"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
status
,
err
:=
service
.
pollVideoTaskDetailed
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
status
)
require
.
Contains
(
t
,
err
.
Error
(),
"reject"
)
require
.
Equal
(
t
,
1
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_BuildSoraMediaURLSigned
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraMediaSigningKey
:
"test-key"
,
SoraMediaSignedURLTTLSeconds
:
600
,
},
}
service
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
url
:=
service
.
buildSoraMediaURL
(
"/image/2025/01/01/a.png"
,
""
)
require
.
Contains
(
t
,
url
,
"/sora/media-signed"
)
require
.
Contains
(
t
,
url
,
"expires="
)
require
.
Contains
(
t
,
url
,
"sig="
)
}
func
TestNormalizeSoraMediaURLs_Empty
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
result
:=
svc
.
normalizeSoraMediaURLs
(
nil
)
require
.
Empty
(
t
,
result
)
result
=
svc
.
normalizeSoraMediaURLs
([]
string
{})
require
.
Empty
(
t
,
result
)
}
func
TestNormalizeSoraMediaURLs_HTTPUrls
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
"http://example.com/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Equal
(
t
,
urls
,
result
)
}
func
TestNormalizeSoraMediaURLs_LocalPaths
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
urls
:=
[]
string
{
"/image/2025/01/a.png"
,
"video/2025/01/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
require
.
Contains
(
t
,
result
[
0
],
"/sora/media"
)
require
.
Contains
(
t
,
result
[
1
],
"/sora/media"
)
}
func
TestNormalizeSoraMediaURLs_SkipsBlank
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
""
,
" "
,
"https://example.com/b.png"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
}
func
TestBuildSoraContent_Image
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"image"
,
[]
string
{
"https://a.com/1.png"
,
"https://a.com/2.png"
})
require
.
Contains
(
t
,
content
,
""
)
require
.
Contains
(
t
,
content
,
""
)
}
func
TestBuildSoraContent_Video
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
[]
string
{
"https://a.com/v.mp4"
})
require
.
Contains
(
t
,
content
,
"<video src='https://a.com/v.mp4'"
)
}
func
TestBuildSoraContent_VideoEmpty
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestBuildSoraContent_Prompt
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"prompt"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestSoraImageSizeFromModel
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"gpt-image"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-landscape"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-portrait"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"something-landscape"
))
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"unknown-model"
))
}
func
TestFirstMediaURL
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
firstMediaURL
(
nil
))
require
.
Equal
(
t
,
""
,
firstMediaURL
([]
string
{}))
require
.
Equal
(
t
,
"a"
,
firstMediaURL
([]
string
{
"a"
,
"b"
}))
}
func
TestSoraProErrorMessage
(
t
*
testing
.
T
)
{
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro-hd"
,
""
),
"Pro-HD"
)
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro"
,
""
),
"Pro"
)
require
.
Empty
(
t
,
soraProErrorMessage
(
"sora-basic"
,
""
))
}
func
TestSoraGatewayService_WriteSoraError_StreamEscapesJSON
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"invalid
\"
prompt
\"\n
line2"
,
true
)
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: error
\n
"
)
require
.
Contains
(
t
,
body
,
"data: [DONE]
\n\n
"
)
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
require
.
GreaterOrEqual
(
t
,
len
(
lines
),
2
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
))
data
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
data
),
&
parsed
))
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errObj
[
"message"
])
}
func
TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
sourceHeaders
:=
http
.
Header
{}
sourceHeaders
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
err
:=
svc
.
handleSoraRequestError
(
context
.
Background
(),
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
},
&
SoraUpstreamError
{
StatusCode
:
http
.
StatusForbidden
,
Message
:
"forbidden"
,
Headers
:
sourceHeaders
,
Body
:
[]
byte
(
`<!DOCTYPE html><title>Just a moment...</title>`
),
},
"sora2-landscape-10s"
,
nil
,
false
,
)
var
failoverErr
*
UpstreamFailoverError
require
.
ErrorAs
(
t
,
err
,
&
failoverErr
)
require
.
NotNil
(
t
,
failoverErr
.
ResponseHeaders
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
sourceHeaders
.
Set
(
"cf-ray"
,
"mutated-after-return"
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
}
func
TestShouldFailoverUpstreamError
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
401
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
404
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
429
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
500
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
502
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
200
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
400
))
}
func
TestWithSoraTimeout_NilService
(
t
*
testing
.
T
)
{
var
svc
*
SoraGatewayService
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_ZeroTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_PositiveTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraRequestTimeoutSeconds
:
30
,
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
NotNil
(
t
,
cancel
)
cancel
()
}
func
TestPollInterval
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
5
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
5
*
time
.
Second
,
svc
.
pollInterval
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollInterval
()
>
0
)
}
func
TestPollMaxAttempts
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
MaxPollAttempts
:
100
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
100
,
svc
.
pollMaxAttempts
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollMaxAttempts
()
>
0
)
}
func
TestDecodeSoraImageInput_BlockPrivateURL
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
"http://127.0.0.1/internal.png"
)
require
.
Error
(
t
,
err
)
}
func
TestDecodeSoraImageInput_DataURL
(
t
*
testing
.
T
)
{
encoded
:=
"data:image/png;base64,aGVsbG8="
data
,
filename
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
encoded
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
data
)
require
.
Contains
(
t
,
filename
,
".png"
)
}
func
TestDecodeBase64WithLimit_ExceedLimit
(
t
*
testing
.
T
)
{
data
,
err
:=
decodeBase64WithLimit
(
"aGVsbG8="
,
3
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
data
)
}
func
TestParseSoraWatermarkOptions_NumericBool
(
t
*
testing
.
T
)
{
body
:=
map
[
string
]
any
{
"watermark_free"
:
float64
(
1
),
"watermark_fallback_on_failure"
:
float64
(
0
),
}
opts
:=
parseSoraWatermarkOptions
(
body
)
require
.
True
(
t
,
opts
.
Enabled
)
require
.
False
(
t
,
opts
.
FallbackOnFailure
)
}
func
TestParseSoraVideoCount
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
nil
))
require
.
Equal
(
t
,
2
,
parseSoraVideoCount
(
map
[
string
]
any
{
"video_count"
:
float64
(
2
)}))
require
.
Equal
(
t
,
3
,
parseSoraVideoCount
(
map
[
string
]
any
{
"videos"
:
"5"
}))
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
map
[
string
]
any
{
"n_variants"
:
0
}))
}
backend/internal/service/sora_gateway_streaming_legacy.go
deleted
100644 → 0
View file @
7b83d6e7
//nolint:unused
package
service
import
(
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var
soraSSEDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
var
soraImageMarkdownRe
=
regexp
.
MustCompile
(
`!\[[^\]]*\]\(([^)]+)\)`
)
var
soraVideoHTMLRe
=
regexp
.
MustCompile
(
`(?i)<video[^>]+src=['"]([^'"]+)['"]`
)
const
soraRewriteBufferLimit
=
2048
type
soraStreamingResult
struct
{
mediaType
string
mediaURLs
[]
string
imageCount
int
imageSize
string
firstTokenMs
*
int
}
func
(
s
*
SoraGatewayService
)
setUpstreamRequestError
(
c
*
gin
.
Context
,
account
*
Account
,
err
error
)
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
}
}
func
(
s
*
SoraGatewayService
)
handleFailoverSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
if
s
.
rateLimitService
==
nil
||
account
==
nil
||
resp
==
nil
{
return
}
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
func
(
s
*
SoraGatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
reqModel
string
)
(
*
ForwardResult
,
error
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
msg
:=
soraProErrorMessage
(
reqModel
,
upstreamMsg
);
msg
!=
""
{
upstreamMsg
=
msg
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
c
!=
nil
{
responsePayload
:=
s
.
buildErrorPayload
(
respBody
,
upstreamMsg
)
c
.
JSON
(
resp
.
StatusCode
,
responsePayload
)
}
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
func
(
s
*
SoraGatewayService
)
buildErrorPayload
(
respBody
[]
byte
,
overrideMessage
string
)
map
[
string
]
any
{
if
len
(
respBody
)
>
0
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
respBody
,
&
payload
);
err
==
nil
{
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
overrideMessage
!=
""
{
errObj
[
"message"
]
=
overrideMessage
}
payload
[
"error"
]
=
errObj
return
payload
}
}
}
return
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
overrideMessage
,
},
}
}
func
(
s
*
SoraGatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
string
,
clientStream
bool
)
(
*
soraStreamingResult
,
error
)
{
if
resp
==
nil
{
return
nil
,
errors
.
New
(
"empty response"
)
}
if
clientStream
{
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
v
:=
resp
.
Header
.
Get
(
"x-request-id"
);
v
!=
""
{
c
.
Header
(
"x-request-id"
,
v
)
}
}
w
:=
c
.
Writer
flusher
,
_
:=
w
.
(
http
.
Flusher
)
contentBuilder
:=
strings
.
Builder
{}
var
firstTokenMs
*
int
var
upstreamError
error
rewriteBuffer
:=
""
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
sendLine
:=
func
(
line
string
)
error
{
if
!
clientStream
{
return
nil
}
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
nil
}
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
if
soraSSEDataRe
.
MatchString
(
line
)
{
data
:=
soraSSEDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
"[DONE]"
{
if
rewriteBuffer
!=
""
{
flushLine
,
flushContent
,
err
:=
s
.
flushSoraRewriteBuffer
(
rewriteBuffer
,
originalModel
)
if
err
!=
nil
{
return
nil
,
err
}
if
flushLine
!=
""
{
if
flushContent
!=
""
{
if
_
,
err
:=
contentBuilder
.
WriteString
(
flushContent
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
flushLine
);
err
!=
nil
{
return
nil
,
err
}
}
rewriteBuffer
=
""
}
if
err
:=
sendLine
(
"data: [DONE]"
);
err
!=
nil
{
return
nil
,
err
}
break
}
updatedLine
,
contentDelta
,
errEvent
:=
s
.
processSoraSSEData
(
data
,
originalModel
,
&
rewriteBuffer
)
if
errEvent
!=
nil
&&
upstreamError
==
nil
{
upstreamError
=
errEvent
}
if
contentDelta
!=
""
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
err
:=
contentBuilder
.
WriteString
(
contentDelta
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
updatedLine
);
err
!=
nil
{
return
nil
,
err
}
continue
}
if
err
:=
sendLine
(
line
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
if
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
response_too_large
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
if
ctx
.
Err
()
==
context
.
DeadlineExceeded
&&
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleStreamTimeout
(
ctx
,
account
,
originalModel
)
}
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
stream_read_error
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
content
:=
contentBuilder
.
String
()
mediaType
,
mediaURLs
:=
s
.
extractSoraMedia
(
content
)
if
mediaType
==
""
&&
isSoraPromptEnhanceModel
(
originalModel
)
{
mediaType
=
"prompt"
}
imageSize
:=
""
imageCount
:=
0
if
mediaType
==
"image"
{
imageSize
=
soraImageSizeFromModel
(
originalModel
)
imageCount
=
len
(
mediaURLs
)
}
if
upstreamError
!=
nil
&&
!
clientStream
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
upstreamError
.
Error
(),
},
})
}
return
nil
,
upstreamError
}
if
!
clientStream
{
response
:=
buildSoraNonStreamResponse
(
content
,
originalModel
)
if
len
(
mediaURLs
)
>
0
{
response
[
"media_url"
]
=
mediaURLs
[
0
]
if
len
(
mediaURLs
)
>
1
{
response
[
"media_urls"
]
=
mediaURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
soraStreamingResult
{
mediaType
:
mediaType
,
mediaURLs
:
mediaURLs
,
imageCount
:
imageCount
,
imageSize
:
imageSize
,
firstTokenMs
:
firstTokenMs
,
},
nil
}
func
(
s
*
SoraGatewayService
)
processSoraSSEData
(
data
string
,
originalModel
string
,
rewriteBuffer
*
string
)
(
string
,
string
,
error
)
{
if
strings
.
TrimSpace
(
data
)
==
""
{
return
"data: "
,
""
,
nil
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
payload
);
err
!=
nil
{
return
"data: "
+
data
,
""
,
nil
}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
return
"data: "
+
data
,
""
,
errors
.
New
(
msg
)
}
}
if
model
,
ok
:=
payload
[
"model"
]
.
(
string
);
ok
&&
model
!=
""
&&
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
contentDelta
,
updated
:=
extractSoraContent
(
payload
)
if
updated
{
var
rewritten
string
if
rewriteBuffer
!=
nil
{
rewritten
=
s
.
rewriteSoraContentWithBuffer
(
contentDelta
,
rewriteBuffer
)
}
else
{
rewritten
=
s
.
rewriteSoraContent
(
contentDelta
)
}
if
rewritten
!=
contentDelta
{
applySoraContent
(
payload
,
rewritten
)
contentDelta
=
rewritten
}
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
"data: "
+
data
,
contentDelta
,
nil
}
return
"data: "
+
string
(
updatedData
),
contentDelta
,
nil
}
func
extractSoraContent
(
payload
map
[
string
]
any
)
(
string
,
bool
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
""
,
false
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
""
,
false
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
delta
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
message
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
return
""
,
false
}
func
applySoraContent
(
payload
map
[
string
]
any
,
content
string
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
delta
[
"content"
]
=
content
choice
[
"delta"
]
=
delta
return
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
message
[
"content"
]
=
content
choice
[
"message"
]
=
message
}
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContentWithBuffer
(
contentDelta
string
,
buffer
*
string
)
string
{
if
buffer
==
nil
{
return
s
.
rewriteSoraContent
(
contentDelta
)
}
if
contentDelta
==
""
&&
*
buffer
==
""
{
return
""
}
combined
:=
*
buffer
+
contentDelta
rewritten
:=
s
.
rewriteSoraContent
(
combined
)
bufferStart
:=
s
.
findSoraRewriteBufferStart
(
rewritten
)
if
bufferStart
<
0
{
*
buffer
=
""
return
rewritten
}
if
len
(
rewritten
)
-
bufferStart
>
soraRewriteBufferLimit
{
bufferStart
=
len
(
rewritten
)
-
soraRewriteBufferLimit
}
output
:=
rewritten
[
:
bufferStart
]
*
buffer
=
rewritten
[
bufferStart
:
]
return
output
}
func
(
s
*
SoraGatewayService
)
findSoraRewriteBufferStart
(
content
string
)
int
{
minIndex
:=
-
1
start
:=
0
for
{
idx
:=
strings
.
Index
(
content
[
start
:
],
"!["
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraImageMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
2
}
lower
:=
strings
.
ToLower
(
content
)
start
=
0
for
{
idx
:=
strings
.
Index
(
lower
[
start
:
],
"<video"
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraVideoMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
len
(
"<video"
)
}
return
minIndex
}
func
hasSoraImageMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraImageMarkdownRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
hasSoraVideoMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraVideoHTMLRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContent
(
content
string
)
string
{
if
content
==
""
{
return
content
}
content
=
soraImageMarkdownRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraImageMarkdownRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
content
=
soraVideoHTMLRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
return
content
}
func
(
s
*
SoraGatewayService
)
flushSoraRewriteBuffer
(
buffer
string
,
originalModel
string
)
(
string
,
string
,
error
)
{
if
buffer
==
""
{
return
""
,
""
,
nil
}
rewritten
:=
s
.
rewriteSoraContent
(
buffer
)
payload
:=
map
[
string
]
any
{
"choices"
:
[]
any
{
map
[
string
]
any
{
"delta"
:
map
[
string
]
any
{
"content"
:
rewritten
,
},
"index"
:
0
,
},
},
}
if
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
"data: "
+
string
(
updatedData
),
rewritten
,
nil
}
func
(
s
*
SoraGatewayService
)
rewriteSoraURL
(
raw
string
)
string
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
raw
}
path
:=
parsed
.
Path
if
!
strings
.
HasPrefix
(
path
,
"/tmp/"
)
&&
!
strings
.
HasPrefix
(
path
,
"/static/"
)
{
return
raw
}
return
s
.
buildSoraMediaURL
(
path
,
parsed
.
RawQuery
)
}
func
(
s
*
SoraGatewayService
)
extractSoraMedia
(
content
string
)
(
string
,
[]
string
)
{
if
content
==
""
{
return
""
,
nil
}
if
match
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
content
);
len
(
match
)
>
1
{
return
"video"
,
[]
string
{
match
[
1
]}
}
imageMatches
:=
soraImageMarkdownRe
.
FindAllStringSubmatch
(
content
,
-
1
)
if
len
(
imageMatches
)
==
0
{
return
""
,
nil
}
urls
:=
make
([]
string
,
0
,
len
(
imageMatches
))
for
_
,
match
:=
range
imageMatches
{
if
len
(
match
)
>
1
{
urls
=
append
(
urls
,
match
[
1
])
}
}
return
"image"
,
urls
}
func
isSoraPromptEnhanceModel
(
model
string
)
bool
{
return
strings
.
HasPrefix
(
strings
.
ToLower
(
strings
.
TrimSpace
(
model
)),
"prompt-enhance"
)
}
backend/internal/service/sora_generation.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type
SoraGeneration
struct
{
ID
int64
`json:"id"`
UserID
int64
`json:"user_id"`
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
Model
string
`json:"model"`
Prompt
string
`json:"prompt"`
MediaType
string
`json:"media_type"`
// video / image
Status
string
`json:"status"`
// pending / generating / completed / failed / cancelled
MediaURL
string
`json:"media_url"`
// 主媒体 URL(预签名或 CDN)
MediaURLs
[]
string
`json:"media_urls"`
// 多图时的 URL 数组
FileSizeBytes
int64
`json:"file_size_bytes"`
StorageType
string
`json:"storage_type"`
// s3 / local / upstream / none
S3ObjectKeys
[]
string
`json:"s3_object_keys"`
// S3 object key 数组
UpstreamTaskID
string
`json:"upstream_task_id"`
ErrorMessage
string
`json:"error_message"`
CreatedAt
time
.
Time
`json:"created_at"`
CompletedAt
*
time
.
Time
`json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const
(
SoraGenStatusPending
=
"pending"
SoraGenStatusGenerating
=
"generating"
SoraGenStatusCompleted
=
"completed"
SoraGenStatusFailed
=
"failed"
SoraGenStatusCancelled
=
"cancelled"
)
// Sora 存储类型常量
const
(
SoraStorageTypeS3
=
"s3"
SoraStorageTypeLocal
=
"local"
SoraStorageTypeUpstream
=
"upstream"
SoraStorageTypeNone
=
"none"
)
// SoraGenerationListParams 查询生成记录的参数。
type
SoraGenerationListParams
struct
{
UserID
int64
Status
string
// 可选筛选
StorageType
string
// 可选筛选
MediaType
string
// 可选筛选
Page
int
PageSize
int
}
// SoraGenerationRepository 生成记录持久化接口。
type
SoraGenerationRepository
interface
{
Create
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
Update
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
CountByUserAndStatus
(
ctx
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
}
backend/internal/service/sora_generation_service.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var
(
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit
=
errors
.
New
(
"sora generation concurrent limit exceeded"
)
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict
=
errors
.
New
(
"sora generation state conflict"
)
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive
=
errors
.
New
(
"sora generation is not active"
)
)
const
soraGenerationActiveLimit
=
3
type
soraGenerationRepoAtomicCreator
interface
{
CreatePendingWithLimit
(
ctx
context
.
Context
,
gen
*
SoraGeneration
,
activeStatuses
[]
string
,
maxActive
int64
)
error
}
type
soraGenerationRepoConditionalUpdater
interface
{
UpdateGeneratingIfPending
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
(
bool
,
error
)
UpdateCompletedIfActive
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateFailedIfActive
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateCancelledIfActive
(
ctx
context
.
Context
,
id
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateStorageIfCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
(
bool
,
error
)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type
SoraGenerationService
struct
{
genRepo
SoraGenerationRepository
s3Storage
*
SoraS3Storage
quotaService
*
SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func
NewSoraGenerationService
(
genRepo
SoraGenerationRepository
,
s3Storage
*
SoraS3Storage
,
quotaService
*
SoraQuotaService
,
)
*
SoraGenerationService
{
return
&
SoraGenerationService
{
genRepo
:
genRepo
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func
(
s
*
SoraGenerationService
)
CreatePending
(
ctx
context
.
Context
,
userID
int64
,
apiKeyID
*
int64
,
model
,
prompt
,
mediaType
string
)
(
*
SoraGeneration
,
error
)
{
gen
:=
&
SoraGeneration
{
UserID
:
userID
,
APIKeyID
:
apiKeyID
,
Model
:
model
,
Prompt
:
prompt
,
MediaType
:
mediaType
,
Status
:
SoraGenStatusPending
,
StorageType
:
SoraStorageTypeNone
,
}
if
atomicCreator
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoAtomicCreator
);
ok
{
if
err
:=
atomicCreator
.
CreatePendingWithLimit
(
ctx
,
gen
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
},
soraGenerationActiveLimit
,
);
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSoraGenerationConcurrencyLimit
)
{
return
nil
,
err
}
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
if
err
:=
s
.
genRepo
.
Create
(
ctx
,
gen
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
// MarkGenerating 标记为生成中。
func
(
s
*
SoraGenerationService
)
MarkGenerating
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateGeneratingIfPending
(
ctx
,
id
,
upstreamTaskID
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusGenerating
gen
.
UpstreamTaskID
=
upstreamTaskID
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCompleted 标记为已完成。
func
(
s
*
SoraGenerationService
)
MarkCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCompletedIfActive
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusCompleted
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkFailed 标记为失败。
func
(
s
*
SoraGenerationService
)
MarkFailed
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateFailedIfActive
(
ctx
,
id
,
errMsg
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusFailed
gen
.
ErrorMessage
=
errMsg
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCancelled 标记为已取消。
func
(
s
*
SoraGenerationService
)
MarkCancelled
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCancelledIfActive
(
ctx
,
id
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationNotActive
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationNotActive
}
gen
.
Status
=
SoraGenStatusCancelled
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func
(
s
*
SoraGenerationService
)
UpdateStorageForCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateStorageIfCompleted
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusCompleted
{
return
ErrSoraGenerationStateConflict
}
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// GetByID 获取记录详情(含权限校验)。
func
(
s
*
SoraGenerationService
)
GetByID
(
ctx
context
.
Context
,
id
,
userID
int64
)
(
*
SoraGeneration
,
error
)
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
gen
.
UserID
!=
userID
{
return
nil
,
fmt
.
Errorf
(
"无权访问此生成记录"
)
}
return
gen
,
nil
}
// List 查询生成记录列表(分页 + 筛选)。
func
(
s
*
SoraGenerationService
)
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
params
.
Page
<=
0
{
params
.
Page
=
1
}
if
params
.
PageSize
<=
0
{
params
.
PageSize
=
20
}
if
params
.
PageSize
>
100
{
params
.
PageSize
=
100
}
return
s
.
genRepo
.
List
(
ctx
,
params
)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func
(
s
*
SoraGenerationService
)
Delete
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
UserID
!=
userID
{
return
fmt
.
Errorf
(
"无权删除此生成记录"
)
}
// 清理 S3 文件
if
gen
.
StorageType
==
SoraStorageTypeS3
&&
len
(
gen
.
S3ObjectKeys
)
>
0
&&
s
.
s3Storage
!=
nil
{
if
err
:=
s
.
s3Storage
.
DeleteObjects
(
ctx
,
gen
.
S3ObjectKeys
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] S3 清理失败 id=%d err=%v"
,
id
,
err
)
}
}
// 释放配额(S3/本地均释放)
if
gen
.
FileSizeBytes
>
0
&&
(
gen
.
StorageType
==
SoraStorageTypeS3
||
gen
.
StorageType
==
SoraStorageTypeLocal
)
&&
s
.
quotaService
!=
nil
{
if
err
:=
s
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
gen
.
FileSizeBytes
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 配额释放失败 id=%d err=%v"
,
id
,
err
)
}
}
return
s
.
genRepo
.
Delete
(
ctx
,
id
)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func
(
s
*
SoraGenerationService
)
CountActiveByUser
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
s
.
genRepo
.
CountByUserAndStatus
(
ctx
,
userID
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func
(
s
*
SoraGenerationService
)
ResolveMediaURLs
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
gen
==
nil
||
gen
.
StorageType
!=
SoraStorageTypeS3
||
s
.
s3Storage
==
nil
{
return
nil
}
if
len
(
gen
.
S3ObjectKeys
)
==
0
{
return
nil
}
urls
:=
make
([]
string
,
len
(
gen
.
S3ObjectKeys
))
var
wg
sync
.
WaitGroup
var
firstErr
error
var
errMu
sync
.
Mutex
for
idx
,
key
:=
range
gen
.
S3ObjectKeys
{
wg
.
Add
(
1
)
go
func
(
i
int
,
objectKey
string
)
{
defer
wg
.
Done
()
url
,
err
:=
s
.
s3Storage
.
GetAccessURL
(
ctx
,
objectKey
)
if
err
!=
nil
{
errMu
.
Lock
()
if
firstErr
==
nil
{
firstErr
=
err
}
errMu
.
Unlock
()
return
}
urls
[
i
]
=
url
}(
idx
,
key
)
}
wg
.
Wait
()
if
firstErr
!=
nil
{
return
firstErr
}
gen
.
MediaURL
=
urls
[
0
]
gen
.
MediaURLs
=
urls
return
nil
}
backend/internal/service/sora_generation_service_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
service
import
(
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var
_
SoraGenerationRepository
=
(
*
stubGenRepo
)(
nil
)
type
stubGenRepo
struct
{
gens
map
[
int64
]
*
SoraGeneration
nextID
int64
createErr
error
getErr
error
updateErr
error
deleteErr
error
listErr
error
countErr
error
countValue
int64
}
func
newStubGenRepo
()
*
stubGenRepo
{
return
&
stubGenRepo
{
gens
:
make
(
map
[
int64
]
*
SoraGeneration
),
nextID
:
1
}
}
func
(
r
*
stubGenRepo
)
Create
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
createErr
!=
nil
{
return
r
.
createErr
}
gen
.
ID
=
r
.
nextID
gen
.
CreatedAt
=
time
.
Now
()
r
.
nextID
++
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
if
gen
,
ok
:=
r
.
gens
[
id
];
ok
{
return
gen
,
nil
}
return
nil
,
fmt
.
Errorf
(
"not found"
)
}
func
(
r
*
stubGenRepo
)
Update
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
Delete
(
_
context
.
Context
,
id
int64
)
error
{
if
r
.
deleteErr
!=
nil
{
return
r
.
deleteErr
}
delete
(
r
.
gens
,
id
)
return
nil
}
func
(
r
*
stubGenRepo
)
List
(
_
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
r
.
listErr
!=
nil
{
return
nil
,
0
,
r
.
listErr
}
var
result
[]
*
SoraGeneration
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
!=
params
.
UserID
{
continue
}
if
params
.
Status
!=
""
&&
gen
.
Status
!=
params
.
Status
{
continue
}
if
params
.
StorageType
!=
""
&&
gen
.
StorageType
!=
params
.
StorageType
{
continue
}
if
params
.
MediaType
!=
""
&&
gen
.
MediaType
!=
params
.
MediaType
{
continue
}
result
=
append
(
result
,
gen
)
}
return
result
,
int64
(
len
(
result
)),
nil
}
func
(
r
*
stubGenRepo
)
CountByUserAndStatus
(
_
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
{
if
r
.
countErr
!=
nil
{
return
0
,
r
.
countErr
}
if
r
.
countValue
>
0
{
return
r
.
countValue
,
nil
}
var
count
int64
statusSet
:=
make
(
map
[
string
]
struct
{})
for
_
,
s
:=
range
statuses
{
statusSet
[
s
]
=
struct
{}{}
}
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
==
userID
{
if
_
,
ok
:=
statusSet
[
gen
.
Status
];
ok
{
count
++
}
}
}
return
count
,
nil
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var
_
UserRepository
=
(
*
stubUserRepoForQuota
)(
nil
)
type
stubUserRepoForQuota
struct
{
users
map
[
int64
]
*
User
updateErr
error
}
func
newStubUserRepoForQuota
()
*
stubUserRepoForQuota
{
return
&
stubUserRepoForQuota
{
users
:
make
(
map
[
int64
]
*
User
)}
}
func
(
r
*
stubUserRepoForQuota
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
if
u
,
ok
:=
r
.
users
[
id
];
ok
{
return
u
,
nil
}
return
nil
,
fmt
.
Errorf
(
"user not found"
)
}
func
(
r
*
stubUserRepoForQuota
)
Update
(
_
context
.
Context
,
user
*
User
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
users
[
user
.
ID
]
=
user
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetByEmail
(
context
.
Context
,
string
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetFirstAdmin
(
context
.
Context
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
func
newS3StorageWithCDN
(
cdnURL
string
)
*
SoraS3Storage
{
storage
:=
&
SoraS3Storage
{}
storage
.
cfg
=
&
SoraS3Settings
{
Enabled
:
true
,
Bucket
:
"test-bucket"
,
CDNURL
:
cdnURL
,
}
// 需要 non-nil client 使 getClient 命中缓存
storage
.
client
=
s3
.
New
(
s3
.
Options
{})
return
storage
}
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
func
newS3StorageFailingDelete
()
*
SoraS3Storage
{
return
&
SoraS3Storage
{}
// settingService 为 nil → getConfig 返回 error
}
// ==================== CreatePending ====================
func
TestCreatePending_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"一只猫跳舞"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
UserID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
require
.
Equal
(
t
,
"一只猫跳舞"
,
gen
.
Prompt
)
require
.
Equal
(
t
,
"video"
,
gen
.
MediaType
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
require
.
Equal
(
t
,
SoraStorageTypeNone
,
gen
.
StorageType
)
require
.
Nil
(
t
,
gen
.
APIKeyID
)
}
func
TestCreatePending_WithAPIKeyID
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyID
:=
int64
(
42
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
&
apiKeyID
,
"gpt-image"
,
"画一朵花"
,
"image"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
gen
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
gen
.
APIKeyID
)
}
func
TestCreatePending_RepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
createErr
=
fmt
.
Errorf
(
"db write error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"create generation"
)
}
// ==================== MarkGenerating ====================
func
TestMarkGenerating_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
"upstream-task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
"upstream-task-123"
,
repo
.
gens
[
1
]
.
UpstreamTaskID
)
}
func
TestMarkGenerating_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
999
,
""
)
require
.
Error
(
t
,
err
)
}
func
TestMarkGenerating_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
""
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCompleted ====================
func
TestMarkCompleted_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"https://cdn.example.com/video.mp4"
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
SoraStorageTypeS3
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
1048576
,
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
gen
.
Status
)
require
.
Equal
(
t
,
"https://cdn.example.com/video.mp4"
,
gen
.
MediaURL
)
require
.
Equal
(
t
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
gen
.
MediaURLs
)
require
.
Equal
(
t
,
SoraStorageTypeS3
,
gen
.
StorageType
)
require
.
Equal
(
t
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
gen
.
S3ObjectKeys
)
require
.
Equal
(
t
,
int64
(
1048576
),
gen
.
FileSizeBytes
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkCompleted_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
999
,
""
,
nil
,
""
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCompleted_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"url"
,
nil
,
SoraStorageTypeUpstream
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkFailed ====================
func
TestMarkFailed_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"上游返回 500 错误"
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusFailed
,
gen
.
Status
)
require
.
Equal
(
t
,
"上游返回 500 错误"
,
gen
.
ErrorMessage
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkFailed_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
999
,
"error"
)
require
.
Error
(
t
,
err
)
}
func
TestMarkFailed_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"err"
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCancelled ====================
func
TestMarkCancelled_Pending
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
CompletedAt
)
}
func
TestMarkCancelled_Generating
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestMarkCancelled_Completed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrSoraGenerationNotActive
)
}
func
TestMarkCancelled_Failed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusFailed
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_AlreadyCancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCancelled
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
999
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== GetByID ====================
func
TestGetByID_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
Model
:
"sora2-landscape-10s"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
}
func
TestGetByID_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestGetByID_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
}
// ==================== List ====================
func
TestList_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusPending
,
MediaType
:
"image"
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gens
,
total
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
gens
,
2
)
// 只有 userID=1 的
require
.
Equal
(
t
,
int64
(
2
),
total
)
}
func
TestList_DefaultPagination
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
NoError
(
t
,
err
)
}
func
TestList_MaxPageSize
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// pageSize > 100 → 应限制为 100
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
200
})
require
.
NoError
(
t
,
err
)
}
func
TestList_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
listErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
Error
(
t
,
err
)
}
// ==================== Delete ====================
func
TestDelete_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
StorageType
:
SoraStorageTypeUpstream
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
func
TestDelete_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
}
func
TestDelete_S3Cleanup_NilS3
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// s3Storage 为 nil,跳过清理
}
func
TestDelete_QuotaRelease_NilQuota
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// quotaService 为 nil,跳过释放
}
func
TestDelete_NonS3NoCleanup
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
}
func
TestDelete_DeleteRepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeUpstream
}
repo
.
deleteErr
=
fmt
.
Errorf
(
"delete failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
}
// ==================== CountActiveByUser ====================
func
TestCountActiveByUser_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
// 不算
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
count
)
}
func
TestCountActiveByUser_NoActive
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
count
)
}
func
TestCountActiveByUser_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
countErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== ResolveMediaURLs ====================
func
TestResolveMediaURLs_NilGen
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
nil
))
}
func
TestResolveMediaURLs_NonS3
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeUpstream
,
MediaURL
:
"https://original.com/v.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"https://original.com/v.mp4"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3NilStorage
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
}
func
TestResolveMediaURLs_Local
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeLocal
,
MediaURL
:
"/video/2024/01/01/file.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"/video/2024/01/01/file.mp4"
,
gen
.
MediaURL
)
// 不变
}
// ==================== 状态流转完整测试 ====================
func
TestStatusTransition_PendingToCompletedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 1. 创建 pending
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
// 2. 标记 generating
err
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
"task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
// 3. 标记 completed
err
=
svc
.
MarkCompleted
(
context
.
Background
(),
gen
.
ID
,
"https://s3.com/video.mp4"
,
nil
,
SoraStorageTypeS3
,
[]
string
{
"key"
},
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_PendingToFailedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
gen
.
ID
,
"上游超时"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusFailed
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
require
.
Equal
(
t
,
"上游超时"
,
repo
.
gens
[
gen
.
ID
]
.
ErrorMessage
)
}
func
TestStatusTransition_PendingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_GeneratingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
// ==================== 权限隔离测试 ====================
func
TestUserIsolation_CannotAccessOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
// 用户 2 尝试访问用户 1 的记录
_
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestUserIsolation_CannotDeleteOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
Delete
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
// ==================== Delete: S3 清理 + 配额释放路径 ====================
func
TestDelete_S3Cleanup_WithS3Storage
(
t
*
testing
.
T
)
{
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
// 验证 Delete 仍然成功(S3 错误只是记录日志)
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/abc.mp4"
},
}
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// S3 清理失败不影响删除
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_QuotaRelease_WithQuotaService
(
t
*
testing
.
T
)
{
// 有配额服务时,删除 S3 类型记录会释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1048576
,
// 1MB
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2097152
}
// 2MB
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// 配额应被释放: 2MB - 1MB = 1MB
require
.
Equal
(
t
,
int64
(
1048576
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_S3Cleanup_And_QuotaRelease
(
t
*
testing
.
T
)
{
// S3 清理 + 配额释放同时触发
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
},
FileSizeBytes
:
512
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
require
.
Equal
(
t
,
int64
(
512
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_LocalStorage
(
t
*
testing
.
T
)
{
// 本地存储同样需要释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2048
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_ZeroFileSize
(
t
*
testing
.
T
)
{
// FileSizeBytes=0 跳过配额释放
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
0
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func
TestResolveMediaURLs_S3_CDN_SingleKey
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/video.mp4"
,
gen
.
MediaURL
)
}
func
TestResolveMediaURLs_S3_CDN_MultipleKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com/"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
"sora/1/2024/01/01/img3.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
// 主 URL 更新为第一个 key 的 CDN URL
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURL
)
// 多图 URLs 全部更新
require
.
Len
(
t
,
gen
.
MediaURLs
,
3
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURLs
[
0
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img2.png"
,
gen
.
MediaURLs
[
1
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img3.png"
,
gen
.
MediaURLs
[
2
])
}
func
TestResolveMediaURLs_S3_EmptyKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"original"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3_GetAccessURL_Error
(
t
*
testing
.
T
)
{
// 使用无 settingService 的 S3 Storage,getClient 会失败
s3Storage
:=
newS3StorageFailingDelete
()
// 同样 GetAccessURL 也会失败
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// GetAccessURL 失败应传播错误
}
func
TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond
(
t
*
testing
.
T
)
{
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// 第一个 key 的 GetAccessURL 就会失败
}
backend/internal/service/sora_media_cleanup_service.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
var
soraCleanupCronParser
=
cron
.
NewParser
(
cron
.
Minute
|
cron
.
Hour
|
cron
.
Dom
|
cron
.
Month
|
cron
.
Dow
)
// SoraMediaCleanupService 定期清理本地媒体文件
type
SoraMediaCleanupService
struct
{
storage
*
SoraMediaStorage
cfg
*
config
.
Config
cron
*
cron
.
Cron
startOnce
sync
.
Once
stopOnce
sync
.
Once
}
func
NewSoraMediaCleanupService
(
storage
*
SoraMediaStorage
,
cfg
*
config
.
Config
)
*
SoraMediaCleanupService
{
return
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraMediaCleanupService
)
Start
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
if
!
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Enabled
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (disabled)"
)
return
}
if
s
.
storage
==
nil
||
!
s
.
storage
.
Enabled
()
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (storage disabled)"
)
return
}
s
.
startOnce
.
Do
(
func
()
{
schedule
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Schedule
)
if
schedule
==
""
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (empty schedule)"
)
return
}
loc
:=
time
.
Local
if
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
)
!=
""
{
if
parsed
,
err
:=
time
.
LoadLocation
(
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
));
err
==
nil
&&
parsed
!=
nil
{
loc
=
parsed
}
}
c
:=
cron
.
New
(
cron
.
WithParser
(
soraCleanupCronParser
),
cron
.
WithLocation
(
loc
))
if
_
,
err
:=
c
.
AddFunc
(
schedule
,
func
()
{
s
.
runCleanup
()
});
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (invalid schedule=%q): %v"
,
schedule
,
err
)
return
}
s
.
cron
=
c
s
.
cron
.
Start
()
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] started (schedule=%q tz=%s)"
,
schedule
,
loc
.
String
())
})
}
func
(
s
*
SoraMediaCleanupService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
cron
!=
nil
{
ctx
:=
s
.
cron
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
case
<-
time
.
After
(
3
*
time
.
Second
)
:
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cron stop timed out"
)
}
}
})
}
func
(
s
*
SoraMediaCleanupService
)
runCleanup
()
{
if
s
.
cfg
==
nil
||
s
.
storage
==
nil
{
return
}
retention
:=
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
RetentionDays
if
retention
<=
0
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] skipped (retention_days=%d)"
,
retention
)
return
}
cutoff
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
retention
)
deleted
:=
0
roots
:=
[]
string
{
s
.
storage
.
ImageRoot
(),
s
.
storage
.
VideoRoot
()}
for
_
,
root
:=
range
roots
{
if
root
==
""
{
continue
}
_
=
filepath
.
Walk
(
root
,
func
(
p
string
,
info
os
.
FileInfo
,
err
error
)
error
{
if
err
!=
nil
{
return
nil
}
if
info
.
IsDir
()
{
return
nil
}
if
info
.
ModTime
()
.
Before
(
cutoff
)
{
if
rmErr
:=
os
.
Remove
(
p
);
rmErr
==
nil
{
deleted
++
}
}
return
nil
})
}
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cleanup finished, deleted=%d"
,
deleted
)
}
backend/internal/service/sora_media_cleanup_service_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
service
import
(
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaCleanupService_RunCleanup_NilCfg
(
t
*
testing
.
T
)
{
storage
:=
&
SoraMediaStorage
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
nil
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_NilStorage
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
nil
,
cfg
:
cfg
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_ZeroRetention
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
0
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
// retention=0 应跳过清理
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_Start_NilCfg
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
nil
)
svc
.
Start
()
// cfg == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_StorageDisabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// storage == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_WithTimezone
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Timezone
:
"Asia/Shanghai"
,
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Start_Disabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
false
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// 不应 panic,也不应启动 cron
}
func
TestSoraMediaCleanupService_Start_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Start
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Start_EmptySchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
""
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 空 schedule 不应启动
}
func
TestSoraMediaCleanupService_Start_InvalidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"invalid-cron"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 无效 schedule 不应 panic
}
func
TestSoraMediaCleanupService_Start_ValidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Stop_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Stop
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Stop_WithoutStart
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
&
config
.
Config
{})
svc
.
Stop
()
// cron 未启动时 Stop 不应 panic
}
func
TestSoraMediaCleanupService_RunCleanup
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
1
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
require
.
NoError
(
t
,
storage
.
EnsureLocalDirs
())
oldImage
:=
filepath
.
Join
(
storage
.
ImageRoot
(),
"old.png"
)
newVideo
:=
filepath
.
Join
(
storage
.
VideoRoot
(),
"new.mp4"
)
require
.
NoError
(
t
,
os
.
WriteFile
(
oldImage
,
[]
byte
(
"old"
),
0
o644
))
require
.
NoError
(
t
,
os
.
WriteFile
(
newVideo
,
[]
byte
(
"new"
),
0
o644
))
oldTime
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
require
.
NoError
(
t
,
os
.
Chtimes
(
oldImage
,
oldTime
,
oldTime
))
cleanup
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
cleanup
.
runCleanup
()
require
.
NoFileExists
(
t
,
oldImage
)
require
.
FileExists
(
t
,
newVideo
)
}
backend/internal/service/sora_media_sign.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func
SignSoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
key
string
)
string
{
key
=
strings
.
TrimSpace
(
key
)
if
key
==
""
{
return
""
}
mac
:=
hmac
.
New
(
sha256
.
New
,
[]
byte
(
key
))
if
_
,
err
:=
mac
.
Write
([]
byte
(
buildSoraMediaSignPayload
(
path
,
query
)));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
"|"
));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
strconv
.
FormatInt
(
expires
,
10
)));
err
!=
nil
{
return
""
}
return
hex
.
EncodeToString
(
mac
.
Sum
(
nil
))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func
VerifySoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
signature
string
,
key
string
)
bool
{
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
return
false
}
expected
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
expected
==
""
{
return
false
}
return
hmac
.
Equal
([]
byte
(
signature
),
[]
byte
(
expected
))
}
func
buildSoraMediaSignPayload
(
path
string
,
query
string
)
string
{
if
strings
.
TrimSpace
(
query
)
==
""
{
return
path
}
return
path
+
"?"
+
query
}
backend/internal/service/sora_media_sign_test.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
"testing"
func
TestSoraMediaSignVerify
(
t
*
testing
.
T
)
{
key
:=
"test-key"
path
:=
"/tmp/abc.png"
query
:=
"a=1&b=2"
expires
:=
int64
(
1700000000
)
signature
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
signature
==
""
{
t
.
Fatal
(
"签名为空"
)
}
if
!
VerifySoraMediaURL
(
path
,
query
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名校验失败"
)
}
if
VerifySoraMediaURL
(
path
,
"a=1"
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名参数不同仍然通过"
)
}
if
VerifySoraMediaURL
(
path
,
query
,
expires
+
1
,
signature
,
key
)
{
t
.
Fatal
(
"签名过期校验未失败"
)
}
}
func
TestSoraMediaSignWithEmptyKey
(
t
*
testing
.
T
)
{
signature
:=
SignSoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
""
)
if
signature
!=
""
{
t
.
Fatalf
(
"空密钥不应生成签名"
)
}
if
VerifySoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
"sig"
,
""
)
{
t
.
Fatalf
(
"空密钥不应通过校验"
)
}
}
backend/internal/service/sora_media_storage.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const
(
soraStorageDefaultRoot
=
"/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type
SoraMediaStorage
struct
{
cfg
*
config
.
Config
root
string
imageRoot
string
videoRoot
string
downloadTimeout
time
.
Duration
maxDownloadBytes
int64
fallbackToUpstream
bool
debug
bool
sem
chan
struct
{}
ready
bool
}
func
NewSoraMediaStorage
(
cfg
*
config
.
Config
)
*
SoraMediaStorage
{
storage
:=
&
SoraMediaStorage
{
cfg
:
cfg
}
storage
.
refreshConfig
()
if
storage
.
Enabled
()
{
if
err
:=
storage
.
EnsureLocalDirs
();
err
!=
nil
{
log
.
Printf
(
"[SoraStorage] 初始化失败: %v"
,
err
)
}
}
return
storage
}
func
(
s
*
SoraMediaStorage
)
Enabled
()
bool
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
false
}
return
strings
.
ToLower
(
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Type
))
==
"local"
}
func
(
s
*
SoraMediaStorage
)
Root
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
root
}
func
(
s
*
SoraMediaStorage
)
ImageRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
imageRoot
}
func
(
s
*
SoraMediaStorage
)
VideoRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
videoRoot
}
func
(
s
*
SoraMediaStorage
)
refreshConfig
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
root
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
LocalPath
)
if
root
==
""
{
root
=
soraStorageDefaultRoot
}
root
=
filepath
.
Clean
(
root
)
if
!
filepath
.
IsAbs
(
root
)
{
if
absRoot
,
err
:=
filepath
.
Abs
(
root
);
err
==
nil
{
root
=
absRoot
}
}
s
.
root
=
root
s
.
imageRoot
=
filepath
.
Join
(
root
,
"image"
)
s
.
videoRoot
=
filepath
.
Join
(
root
,
"video"
)
maxConcurrent
:=
s
.
cfg
.
Sora
.
Storage
.
MaxConcurrentDownloads
if
maxConcurrent
<=
0
{
maxConcurrent
=
4
}
timeoutSeconds
:=
s
.
cfg
.
Sora
.
Storage
.
DownloadTimeoutSeconds
if
timeoutSeconds
<=
0
{
timeoutSeconds
=
120
}
s
.
downloadTimeout
=
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
maxBytes
:=
s
.
cfg
.
Sora
.
Storage
.
MaxDownloadBytes
if
maxBytes
<=
0
{
maxBytes
=
200
<<
20
}
s
.
maxDownloadBytes
=
maxBytes
s
.
fallbackToUpstream
=
s
.
cfg
.
Sora
.
Storage
.
FallbackToUpstream
s
.
debug
=
s
.
cfg
.
Sora
.
Storage
.
Debug
s
.
sem
=
make
(
chan
struct
{},
maxConcurrent
)
}
// EnsureLocalDirs 创建并校验本地目录
func
(
s
*
SoraMediaStorage
)
EnsureLocalDirs
()
error
{
if
s
==
nil
||
!
s
.
Enabled
()
{
return
nil
}
if
err
:=
os
.
MkdirAll
(
s
.
imageRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create image dir: %w"
,
err
)
}
if
err
:=
os
.
MkdirAll
(
s
.
videoRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create video dir: %w"
,
err
)
}
s
.
ready
=
true
return
nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func
(
s
*
SoraMediaStorage
)
StoreFromURLs
(
ctx
context
.
Context
,
mediaType
string
,
urls
[]
string
)
([]
string
,
error
)
{
if
len
(
urls
)
==
0
{
return
nil
,
nil
}
if
s
==
nil
||
!
s
.
Enabled
()
{
return
urls
,
nil
}
if
!
s
.
ready
{
if
err
:=
s
.
EnsureLocalDirs
();
err
!=
nil
{
return
nil
,
err
}
}
results
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
relative
,
err
:=
s
.
downloadAndStore
(
ctx
,
mediaType
,
raw
)
if
err
!=
nil
{
if
s
.
fallbackToUpstream
{
results
=
append
(
results
,
raw
)
continue
}
return
nil
,
err
}
results
=
append
(
results
,
relative
)
}
return
results
,
nil
}
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
TotalSizeByRelativePaths
(
paths
[]
string
)
(
int64
,
error
)
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
0
,
nil
}
var
total
int64
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
info
,
err
:=
os
.
Stat
(
localPath
)
if
err
!=
nil
{
if
os
.
IsNotExist
(
err
)
{
continue
}
return
0
,
err
}
if
info
.
Mode
()
.
IsRegular
()
{
total
+=
info
.
Size
()
}
}
return
total
,
nil
}
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
DeleteByRelativePaths
(
paths
[]
string
)
error
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
nil
}
var
lastErr
error
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
if
err
:=
os
.
Remove
(
localPath
);
err
!=
nil
&&
!
os
.
IsNotExist
(
err
)
{
lastErr
=
err
}
}
return
lastErr
}
func
(
s
*
SoraMediaStorage
)
resolveLocalPath
(
relativePath
string
)
(
string
,
error
)
{
if
s
==
nil
||
strings
.
TrimSpace
(
relativePath
)
==
""
{
return
""
,
errors
.
New
(
"empty path"
)
}
cleaned
:=
path
.
Clean
(
relativePath
)
if
!
strings
.
HasPrefix
(
cleaned
,
"/image/"
)
&&
!
strings
.
HasPrefix
(
cleaned
,
"/video/"
)
{
return
""
,
errors
.
New
(
"not a local media path"
)
}
if
strings
.
TrimSpace
(
s
.
root
)
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
relative
:=
strings
.
TrimPrefix
(
cleaned
,
"/"
)
return
filepath
.
Join
(
s
.
root
,
filepath
.
FromSlash
(
relative
)),
nil
}
func
(
s
*
SoraMediaStorage
)
downloadAndStore
(
ctx
context
.
Context
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
if
strings
.
TrimSpace
(
rawURL
)
==
""
{
return
""
,
errors
.
New
(
"empty url"
)
}
root
:=
s
.
imageRoot
if
mediaType
==
"video"
{
root
=
s
.
videoRoot
}
if
root
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
retries
:=
3
for
attempt
:=
1
;
attempt
<=
retries
;
attempt
++
{
release
,
err
:=
s
.
acquire
(
ctx
)
if
err
!=
nil
{
return
""
,
err
}
relative
,
err
:=
s
.
downloadOnce
(
ctx
,
root
,
mediaType
,
rawURL
)
release
()
if
err
==
nil
{
return
relative
,
nil
}
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 下载失败(%d/%d): %s err=%v"
,
attempt
,
retries
,
sanitizeMediaLogURL
(
rawURL
),
err
)
}
if
attempt
<
retries
{
time
.
Sleep
(
time
.
Duration
(
attempt
*
attempt
)
*
time
.
Second
)
continue
}
return
""
,
err
}
return
""
,
errors
.
New
(
"download retries exhausted"
)
}
func
(
s
*
SoraMediaStorage
)
downloadOnce
(
ctx
context
.
Context
,
root
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
rawURL
,
nil
)
if
err
!=
nil
{
return
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
s
.
downloadTimeout
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
return
""
,
fmt
.
Errorf
(
"download failed: %d %s"
,
resp
.
StatusCode
,
string
(
body
))
}
ext
:=
normalizeSoraFileExt
(
fileExtFromURL
(
rawURL
))
if
ext
==
""
{
ext
=
normalizeSoraFileExt
(
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
)))
}
if
ext
==
""
{
ext
=
".bin"
}
if
s
.
maxDownloadBytes
>
0
&&
resp
.
ContentLength
>
s
.
maxDownloadBytes
{
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
resp
.
ContentLength
)
}
storageRoot
,
err
:=
os
.
OpenRoot
(
root
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
storageRoot
.
Close
()
}()
datePath
:=
time
.
Now
()
.
Format
(
"2006/01/02"
)
datePathFS
:=
filepath
.
FromSlash
(
datePath
)
if
err
:=
storageRoot
.
MkdirAll
(
datePathFS
,
0
o755
);
err
!=
nil
{
return
""
,
err
}
filename
:=
uuid
.
NewString
()
+
ext
filePath
:=
filepath
.
Join
(
datePathFS
,
filename
)
out
,
err
:=
storageRoot
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o644
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
out
.
Close
()
}()
limited
:=
io
.
LimitReader
(
resp
.
Body
,
s
.
maxDownloadBytes
+
1
)
written
,
err
:=
io
.
Copy
(
out
,
limited
)
if
err
!=
nil
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
err
}
if
s
.
maxDownloadBytes
>
0
&&
written
>
s
.
maxDownloadBytes
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
written
)
}
relative
:=
path
.
Join
(
"/"
,
mediaType
,
datePath
,
filename
)
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 已落地 %s -> %s"
,
sanitizeMediaLogURL
(
rawURL
),
relative
)
}
return
relative
,
nil
}
func
(
s
*
SoraMediaStorage
)
acquire
(
ctx
context
.
Context
)
(
func
(),
error
)
{
if
s
.
sem
==
nil
{
return
func
()
{},
nil
}
select
{
case
s
.
sem
<-
struct
{}{}
:
return
func
()
{
<-
s
.
sem
},
nil
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
}
}
func
fileExtFromURL
(
raw
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
""
}
ext
:=
path
.
Ext
(
parsed
.
Path
)
return
strings
.
ToLower
(
ext
)
}
func
fileExtFromContentType
(
ct
string
)
string
{
if
ct
==
""
{
return
""
}
if
exts
,
err
:=
mime
.
ExtensionsByType
(
ct
);
err
==
nil
&&
len
(
exts
)
>
0
{
return
strings
.
ToLower
(
exts
[
0
])
}
return
""
}
func
normalizeSoraFileExt
(
ext
string
)
string
{
ext
=
strings
.
ToLower
(
strings
.
TrimSpace
(
ext
))
switch
ext
{
case
".png"
,
".jpg"
,
".jpeg"
,
".gif"
,
".webp"
,
".bmp"
,
".svg"
,
".tif"
,
".tiff"
,
".heic"
,
".mp4"
,
".mov"
,
".webm"
,
".m4v"
,
".avi"
,
".mkv"
,
".3gp"
,
".flv"
:
return
ext
default
:
return
""
}
}
func
removePartialDownload
(
root
*
os
.
Root
,
filePath
string
)
{
if
root
==
nil
||
strings
.
TrimSpace
(
filePath
)
==
""
{
return
}
_
=
root
.
Remove
(
filePath
)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func
sanitizeMediaLogURL
(
rawURL
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
rawURL
)
if
err
!=
nil
{
if
len
(
rawURL
)
>
80
{
return
rawURL
[
:
80
]
+
"..."
}
return
rawURL
}
safe
:=
parsed
.
Scheme
+
"://"
+
parsed
.
Host
+
parsed
.
Path
if
len
(
safe
)
>
120
{
return
safe
[
:
120
]
+
"..."
}
return
safe
}
backend/internal/service/sora_media_storage_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
service
import
(
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaStorage_StoreFromURLs
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"data"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxConcurrentDownloads
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
urls
,
1
)
require
.
True
(
t
,
strings
.
HasPrefix
(
urls
[
0
],
"/image/"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
urls
[
0
],
".png"
))
localPath
:=
filepath
.
Join
(
tmpDir
,
filepath
.
FromSlash
(
strings
.
TrimPrefix
(
urls
[
0
],
"/"
)))
require
.
FileExists
(
t
,
localPath
)
}
func
TestSoraMediaStorage_FallbackToUpstream
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
FallbackToUpstream
:
true
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
url
:=
server
.
URL
+
"/broken.png"
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
url
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
url
},
urls
)
}
func
TestSoraMediaStorage_MaxDownloadBytes
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"too-large"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxDownloadBytes
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
_
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
Error
(
t
,
err
)
}
func
TestNormalizeSoraFileExt
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
".png"
,
normalizeSoraFileExt
(
".PNG"
))
require
.
Equal
(
t
,
".mp4"
,
normalizeSoraFileExt
(
".mp4"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
"../../etc/passwd"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
".php"
))
}
func
TestRemovePartialDownload
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
root
,
err
:=
os
.
OpenRoot
(
tmpDir
)
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
root
.
Close
()
}()
filePath
:=
"partial.bin"
f
,
err
:=
root
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o600
)
require
.
NoError
(
t
,
err
)
_
,
_
=
f
.
WriteString
(
"partial"
)
_
=
f
.
Close
()
removePartialDownload
(
root
,
filePath
)
_
,
err
=
root
.
Stat
(
filePath
)
require
.
Error
(
t
,
err
)
require
.
True
(
t
,
os
.
IsNotExist
(
err
))
}
backend/internal/service/sora_models.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"regexp"
"sort"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
// SoraModelConfig Sora 模型配置
type
SoraModelConfig
struct
{
Type
string
Width
int
Height
int
Orientation
string
Frames
int
Model
string
Size
string
RequirePro
bool
// Prompt-enhance 专用参数
ExpansionLevel
string
DurationS
int
}
var
soraModelConfigs
=
map
[
string
]
SoraModelConfig
{
"gpt-image"
:
{
Type
:
"image"
,
Width
:
360
,
Height
:
360
,
},
"gpt-image-landscape"
:
{
Type
:
"image"
,
Width
:
540
,
Height
:
360
,
},
"gpt-image-portrait"
:
{
Type
:
"image"
,
Width
:
360
,
Height
:
540
,
},
"sora2-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-landscape-25s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
750
,
Model
:
"sy_8"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2-portrait-25s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
750
,
Model
:
"sy_8"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-25s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
750
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-25s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
750
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-hd-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"prompt-enhance-short-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
10
,
},
"prompt-enhance-short-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
15
,
},
"prompt-enhance-short-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
20
,
},
"prompt-enhance-medium-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
10
,
},
"prompt-enhance-medium-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
15
,
},
"prompt-enhance-medium-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
20
,
},
"prompt-enhance-long-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
10
,
},
"prompt-enhance-long-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
15
,
},
"prompt-enhance-long-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
20
,
},
}
var
soraModelIDs
=
[]
string
{
"gpt-image"
,
"gpt-image-landscape"
,
"gpt-image-portrait"
,
"sora2-landscape-10s"
,
"sora2-portrait-10s"
,
"sora2-landscape-15s"
,
"sora2-portrait-15s"
,
"sora2-landscape-25s"
,
"sora2-portrait-25s"
,
"sora2pro-landscape-10s"
,
"sora2pro-portrait-10s"
,
"sora2pro-landscape-15s"
,
"sora2pro-portrait-15s"
,
"sora2pro-landscape-25s"
,
"sora2pro-portrait-25s"
,
"sora2pro-hd-landscape-10s"
,
"sora2pro-hd-portrait-10s"
,
"sora2pro-hd-landscape-15s"
,
"sora2pro-hd-portrait-15s"
,
"prompt-enhance-short-10s"
,
"prompt-enhance-short-15s"
,
"prompt-enhance-short-20s"
,
"prompt-enhance-medium-10s"
,
"prompt-enhance-medium-15s"
,
"prompt-enhance-medium-20s"
,
"prompt-enhance-long-10s"
,
"prompt-enhance-long-15s"
,
"prompt-enhance-long-20s"
,
}
// GetSoraModelConfig 返回 Sora 模型配置
func
GetSoraModelConfig
(
model
string
)
(
SoraModelConfig
,
bool
)
{
key
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
cfg
,
ok
:=
soraModelConfigs
[
key
]
return
cfg
,
ok
}
// SoraModelFamily 模型家族(前端 Sora 客户端使用)
type
SoraModelFamily
struct
{
ID
string
`json:"id"`
Name
string
`json:"name"`
Type
string
`json:"type"`
Orientations
[]
string
`json:"orientations"`
Durations
[]
int
`json:"durations,omitempty"`
}
var
(
videoSuffixRe
=
regexp
.
MustCompile
(
`-(landscape|portrait)-(\d+)s$`
)
imageSuffixRe
=
regexp
.
MustCompile
(
`-(landscape|portrait)$`
)
soraFamilyNames
=
map
[
string
]
string
{
"sora2"
:
"Sora 2"
,
"sora2pro"
:
"Sora 2 Pro"
,
"sora2pro-hd"
:
"Sora 2 Pro HD"
,
"gpt-image"
:
"GPT Image"
,
}
)
// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
func
BuildSoraModelFamilies
()
[]
SoraModelFamily
{
type
familyData
struct
{
modelType
string
orientations
map
[
string
]
bool
durations
map
[
int
]
bool
}
families
:=
make
(
map
[
string
]
*
familyData
)
for
id
,
cfg
:=
range
soraModelConfigs
{
if
cfg
.
Type
==
"prompt_enhance"
{
continue
}
var
famID
,
orientation
string
var
duration
int
switch
cfg
.
Type
{
case
"video"
:
if
m
:=
videoSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
duration
,
_
=
strconv
.
Atoi
(
m
[
2
])
}
case
"image"
:
if
m
:=
imageSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
}
else
{
famID
=
id
orientation
=
"square"
}
}
if
famID
==
""
{
continue
}
fd
,
ok
:=
families
[
famID
]
if
!
ok
{
fd
=
&
familyData
{
modelType
:
cfg
.
Type
,
orientations
:
make
(
map
[
string
]
bool
),
durations
:
make
(
map
[
int
]
bool
),
}
families
[
famID
]
=
fd
}
if
orientation
!=
""
{
fd
.
orientations
[
orientation
]
=
true
}
if
duration
>
0
{
fd
.
durations
[
duration
]
=
true
}
}
// 排序:视频在前、图像在后,同类按名称排序
famIDs
:=
make
([]
string
,
0
,
len
(
families
))
for
id
:=
range
families
{
famIDs
=
append
(
famIDs
,
id
)
}
sort
.
Slice
(
famIDs
,
func
(
i
,
j
int
)
bool
{
fi
,
fj
:=
families
[
famIDs
[
i
]],
families
[
famIDs
[
j
]]
if
fi
.
modelType
!=
fj
.
modelType
{
return
fi
.
modelType
==
"video"
}
return
famIDs
[
i
]
<
famIDs
[
j
]
})
result
:=
make
([]
SoraModelFamily
,
0
,
len
(
famIDs
))
for
_
,
famID
:=
range
famIDs
{
fd
:=
families
[
famID
]
fam
:=
SoraModelFamily
{
ID
:
famID
,
Name
:
soraFamilyNames
[
famID
],
Type
:
fd
.
modelType
,
}
if
fam
.
Name
==
""
{
fam
.
Name
=
famID
}
for
o
:=
range
fd
.
orientations
{
fam
.
Orientations
=
append
(
fam
.
Orientations
,
o
)
}
sort
.
Strings
(
fam
.
Orientations
)
for
d
:=
range
fd
.
durations
{
fam
.
Durations
=
append
(
fam
.
Durations
,
d
)
}
sort
.
Ints
(
fam
.
Durations
)
result
=
append
(
result
,
fam
)
}
return
result
}
// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
// 通过命名约定自动识别视频/图像模型并分组。
func
BuildSoraModelFamiliesFromIDs
(
modelIDs
[]
string
)
[]
SoraModelFamily
{
type
familyData
struct
{
modelType
string
orientations
map
[
string
]
bool
durations
map
[
int
]
bool
}
families
:=
make
(
map
[
string
]
*
familyData
)
for
_
,
id
:=
range
modelIDs
{
id
=
strings
.
ToLower
(
strings
.
TrimSpace
(
id
))
if
id
==
""
||
strings
.
HasPrefix
(
id
,
"prompt-enhance"
)
{
continue
}
var
famID
,
orientation
,
modelType
string
var
duration
int
if
m
:=
videoSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
// 视频模型: {family}-{orientation}-{duration}s
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
duration
,
_
=
strconv
.
Atoi
(
m
[
2
])
modelType
=
"video"
}
else
if
m
:=
imageSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
// 图像模型(带方向): {family}-{orientation}
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
modelType
=
"image"
}
else
if
cfg
,
ok
:=
soraModelConfigs
[
id
];
ok
&&
cfg
.
Type
==
"image"
{
// 已知的无后缀图像模型(如 gpt-image)
famID
=
id
orientation
=
"square"
modelType
=
"image"
}
else
if
strings
.
Contains
(
id
,
"image"
)
{
// 未知但名称包含 image 的模型,推断为图像模型
famID
=
id
orientation
=
"square"
modelType
=
"image"
}
else
{
continue
}
if
famID
==
""
{
continue
}
fd
,
ok
:=
families
[
famID
]
if
!
ok
{
fd
=
&
familyData
{
modelType
:
modelType
,
orientations
:
make
(
map
[
string
]
bool
),
durations
:
make
(
map
[
int
]
bool
),
}
families
[
famID
]
=
fd
}
if
orientation
!=
""
{
fd
.
orientations
[
orientation
]
=
true
}
if
duration
>
0
{
fd
.
durations
[
duration
]
=
true
}
}
famIDs
:=
make
([]
string
,
0
,
len
(
families
))
for
id
:=
range
families
{
famIDs
=
append
(
famIDs
,
id
)
}
sort
.
Slice
(
famIDs
,
func
(
i
,
j
int
)
bool
{
fi
,
fj
:=
families
[
famIDs
[
i
]],
families
[
famIDs
[
j
]]
if
fi
.
modelType
!=
fj
.
modelType
{
return
fi
.
modelType
==
"video"
}
return
famIDs
[
i
]
<
famIDs
[
j
]
})
result
:=
make
([]
SoraModelFamily
,
0
,
len
(
famIDs
))
for
_
,
famID
:=
range
famIDs
{
fd
:=
families
[
famID
]
fam
:=
SoraModelFamily
{
ID
:
famID
,
Name
:
soraFamilyNames
[
famID
],
Type
:
fd
.
modelType
,
}
if
fam
.
Name
==
""
{
fam
.
Name
=
famID
}
for
o
:=
range
fd
.
orientations
{
fam
.
Orientations
=
append
(
fam
.
Orientations
,
o
)
}
sort
.
Strings
(
fam
.
Orientations
)
for
d
:=
range
fd
.
durations
{
fam
.
Durations
=
append
(
fam
.
Durations
,
d
)
}
sort
.
Ints
(
fam
.
Durations
)
result
=
append
(
result
,
fam
)
}
return
result
}
// DefaultSoraModels returns the default Sora model list.
func
DefaultSoraModels
(
cfg
*
config
.
Config
)
[]
openai
.
Model
{
models
:=
make
([]
openai
.
Model
,
0
,
len
(
soraModelIDs
))
for
_
,
id
:=
range
soraModelIDs
{
models
=
append
(
models
,
openai
.
Model
{
ID
:
id
,
Object
:
"model"
,
OwnedBy
:
"openai"
,
Type
:
"model"
,
DisplayName
:
id
,
})
}
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SoraModelFilters
.
HidePromptEnhance
{
filtered
:=
models
[
:
0
]
for
_
,
model
:=
range
models
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
model
.
ID
),
"prompt-enhance"
)
{
continue
}
filtered
=
append
(
filtered
,
model
)
}
models
=
filtered
}
return
models
}
backend/internal/service/sora_quota_service.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService 管理 Sora 用户存储配额。
// 配额优先级:用户级 → 分组级 → 系统默认值。
type
SoraQuotaService
struct
{
userRepo
UserRepository
groupRepo
GroupRepository
settingService
*
SettingService
}
// NewSoraQuotaService 创建配额服务实例。
func
NewSoraQuotaService
(
userRepo
UserRepository
,
groupRepo
GroupRepository
,
settingService
*
SettingService
,
)
*
SoraQuotaService
{
return
&
SoraQuotaService
{
userRepo
:
userRepo
,
groupRepo
:
groupRepo
,
settingService
:
settingService
,
}
}
// QuotaInfo 返回给客户端的配额信息。
type
QuotaInfo
struct
{
QuotaBytes
int64
`json:"quota_bytes"`
// 总配额(0 表示无限制)
UsedBytes
int64
`json:"used_bytes"`
// 已使用
AvailableBytes
int64
`json:"available_bytes"`
// 剩余可用(无限制时为 0)
QuotaSource
string
`json:"quota_source"`
// 配额来源:user / group / system / unlimited
Source
string
`json:"source,omitempty"`
// 兼容旧字段
}
// ErrSoraStorageQuotaExceeded 表示配额不足。
var
ErrSoraStorageQuotaExceeded
=
errors
.
New
(
"sora storage quota exceeded"
)
// QuotaExceededError 包含配额不足的上下文信息。
type
QuotaExceededError
struct
{
QuotaBytes
int64
UsedBytes
int64
}
func
(
e
*
QuotaExceededError
)
Error
()
string
{
if
e
==
nil
{
return
"存储配额不足"
}
return
fmt
.
Sprintf
(
"存储配额不足(已用 %d / 配额 %d 字节)"
,
e
.
UsedBytes
,
e
.
QuotaBytes
)
}
type
soraQuotaAtomicUserRepository
interface
{
AddSoraStorageUsageWithQuota
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
,
effectiveQuota
int64
)
(
int64
,
error
)
ReleaseSoraStorageUsageAtomic
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
)
(
int64
,
error
)
}
// GetQuota 获取用户的存储配额信息。
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
func
(
s
*
SoraQuotaService
)
GetQuota
(
ctx
context
.
Context
,
userID
int64
)
(
*
QuotaInfo
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
info
:=
&
QuotaInfo
{
UsedBytes
:
user
.
SoraStorageUsedBytes
,
}
// 1. 用户级配额
if
user
.
SoraStorageQuotaBytes
>
0
{
info
.
QuotaBytes
=
user
.
SoraStorageQuotaBytes
info
.
QuotaSource
=
"user"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
// 2. 分组级配额(取用户可用分组中最大的配额)
if
len
(
user
.
AllowedGroups
)
>
0
{
var
maxGroupQuota
int64
for
_
,
gid
:=
range
user
.
AllowedGroups
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
gid
)
if
err
!=
nil
{
continue
}
if
group
.
SoraStorageQuotaBytes
>
maxGroupQuota
{
maxGroupQuota
=
group
.
SoraStorageQuotaBytes
}
}
if
maxGroupQuota
>
0
{
info
.
QuotaBytes
=
maxGroupQuota
info
.
QuotaSource
=
"group"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
}
// 3. 系统默认值
defaultQuota
:=
s
.
getSystemDefaultQuota
(
ctx
)
if
defaultQuota
>
0
{
info
.
QuotaBytes
=
defaultQuota
info
.
QuotaSource
=
"system"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
// 无配额限制
info
.
QuotaSource
=
"unlimited"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
0
return
info
,
nil
}
// CheckQuota 检查用户是否有足够的存储配额。
// 返回 nil 表示配额充足或无限制。
func
(
s
*
SoraQuotaService
)
CheckQuota
(
ctx
context
.
Context
,
userID
int64
,
additionalBytes
int64
)
error
{
quota
,
err
:=
s
.
GetQuota
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
// 0 表示无限制
if
quota
.
QuotaBytes
==
0
{
return
nil
}
if
quota
.
UsedBytes
+
additionalBytes
>
quota
.
QuotaBytes
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
return
nil
}
// AddUsage 原子累加用量(上传成功后调用)。
func
(
s
*
SoraQuotaService
)
AddUsage
(
ctx
context
.
Context
,
userID
int64
,
bytes
int64
)
error
{
if
bytes
<=
0
{
return
nil
}
quota
,
err
:=
s
.
GetQuota
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
if
quota
.
QuotaBytes
>
0
&&
quota
.
UsedBytes
+
bytes
>
quota
.
QuotaBytes
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
if
repo
,
ok
:=
s
.
userRepo
.
(
soraQuotaAtomicUserRepository
);
ok
{
newUsed
,
err
:=
repo
.
AddSoraStorageUsageWithQuota
(
ctx
,
userID
,
bytes
,
quota
.
QuotaBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSoraStorageQuotaExceeded
)
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
return
fmt
.
Errorf
(
"update user quota usage (atomic): %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 累加用量 user=%d +%d total=%d"
,
userID
,
bytes
,
newUsed
)
return
nil
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user for quota update: %w"
,
err
)
}
user
.
SoraStorageUsedBytes
+=
bytes
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota usage: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 累加用量 user=%d +%d total=%d"
,
userID
,
bytes
,
user
.
SoraStorageUsedBytes
)
return
nil
}
// ReleaseUsage 释放用量(删除文件后调用)。
func
(
s
*
SoraQuotaService
)
ReleaseUsage
(
ctx
context
.
Context
,
userID
int64
,
bytes
int64
)
error
{
if
bytes
<=
0
{
return
nil
}
if
repo
,
ok
:=
s
.
userRepo
.
(
soraQuotaAtomicUserRepository
);
ok
{
newUsed
,
err
:=
repo
.
ReleaseSoraStorageUsageAtomic
(
ctx
,
userID
,
bytes
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota release (atomic): %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 释放用量 user=%d -%d total=%d"
,
userID
,
bytes
,
newUsed
)
return
nil
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user for quota release: %w"
,
err
)
}
user
.
SoraStorageUsedBytes
-=
bytes
if
user
.
SoraStorageUsedBytes
<
0
{
user
.
SoraStorageUsedBytes
=
0
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota release: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 释放用量 user=%d -%d total=%d"
,
userID
,
bytes
,
user
.
SoraStorageUsedBytes
)
return
nil
}
func
calcAvailableBytes
(
quotaBytes
,
usedBytes
int64
)
int64
{
if
quotaBytes
<=
0
{
return
0
}
if
usedBytes
>=
quotaBytes
{
return
0
}
return
quotaBytes
-
usedBytes
}
func
(
s
*
SoraQuotaService
)
getSystemDefaultQuota
(
ctx
context
.
Context
)
int64
{
if
s
.
settingService
==
nil
{
return
0
}
settings
,
err
:=
s
.
settingService
.
GetSoraS3Settings
(
ctx
)
if
err
!=
nil
{
return
0
}
return
settings
.
DefaultStorageQuotaBytes
}
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
func
(
s
*
SoraQuotaService
)
GetQuotaFromSettings
(
ctx
context
.
Context
)
int64
{
return
s
.
getSystemDefaultQuota
(
ctx
)
}
// SetUserQuota 设置用户级配额(管理员操作)。
func
SetUserSoraQuota
(
ctx
context
.
Context
,
userRepo
UserRepository
,
userID
int64
,
quotaBytes
int64
)
error
{
user
,
err
:=
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
user
.
SoraStorageQuotaBytes
=
quotaBytes
return
userRepo
.
Update
(
ctx
,
user
)
}
// ParseQuotaBytes 解析配额字符串为字节数。
func
ParseQuotaBytes
(
s
string
)
int64
{
v
,
_
:=
strconv
.
ParseInt
(
s
,
10
,
64
)
return
v
}
backend/internal/service/sora_quota_service_test.go
deleted
100644 → 0
View file @
7b83d6e7
//go:build unit
package
service
import
(
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
var
_
GroupRepository
=
(
*
stubGroupRepoForQuota
)(
nil
)
type
stubGroupRepoForQuota
struct
{
groups
map
[
int64
]
*
Group
}
func
newStubGroupRepoForQuota
()
*
stubGroupRepoForQuota
{
return
&
stubGroupRepoForQuota
{
groups
:
make
(
map
[
int64
]
*
Group
)}
}
func
(
r
*
stubGroupRepoForQuota
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
r
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
fmt
.
Errorf
(
"group not found"
)
}
func
(
r
*
stubGroupRepoForQuota
)
Create
(
context
.
Context
,
*
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
r
.
GetByID
(
context
.
Background
(),
id
)
}
func
(
r
*
stubGroupRepoForQuota
)
Update
(
context
.
Context
,
*
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
DeleteCascade
(
context
.
Context
,
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListActive
(
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListActiveByPlatform
(
context
.
Context
,
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
int64
,
error
)
{
return
0
,
0
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetAccountIDsByGroupIDs
(
context
.
Context
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
BindAccountsToGroup
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
UpdateSortOrders
(
context
.
Context
,
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
// ==================== Stub: SettingRepository (用于 SettingService) ====================
var
_
SettingRepository
=
(
*
stubSettingRepoForQuota
)(
nil
)
type
stubSettingRepoForQuota
struct
{
values
map
[
string
]
string
}
func
newStubSettingRepoForQuota
(
values
map
[
string
]
string
)
*
stubSettingRepoForQuota
{
if
values
==
nil
{
values
=
make
(
map
[
string
]
string
)
}
return
&
stubSettingRepoForQuota
{
values
:
values
}
}
func
(
r
*
stubSettingRepoForQuota
)
Get
(
_
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
&
Setting
{
Key
:
key
,
Value
:
v
},
nil
}
return
nil
,
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForQuota
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
v
,
nil
}
return
""
,
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForQuota
)
Set
(
_
context
.
Context
,
key
,
value
string
)
error
{
r
.
values
[
key
]
=
value
return
nil
}
func
(
r
*
stubSettingRepoForQuota
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
result
:=
make
(
map
[
string
]
string
)
for
_
,
k
:=
range
keys
{
if
v
,
ok
:=
r
.
values
[
k
];
ok
{
result
[
k
]
=
v
}
}
return
result
,
nil
}
func
(
r
*
stubSettingRepoForQuota
)
SetMultiple
(
_
context
.
Context
,
settings
map
[
string
]
string
)
error
{
for
k
,
v
:=
range
settings
{
r
.
values
[
k
]
=
v
}
return
nil
}
func
(
r
*
stubSettingRepoForQuota
)
GetAll
(
_
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
r
.
values
,
nil
}
func
(
r
*
stubSettingRepoForQuota
)
Delete
(
_
context
.
Context
,
key
string
)
error
{
delete
(
r
.
values
,
key
)
return
nil
}
// ==================== GetQuota ====================
func
TestGetQuota_UserLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
// 10MB
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
// 3MB
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
*
1024
*
1024
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
int64
(
3
*
1024
*
1024
),
quota
.
UsedBytes
)
require
.
Equal
(
t
,
"user"
,
quota
.
Source
)
}
func
TestGetQuota_GroupLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
// 用户级无配额
SoraStorageUsedBytes
:
1024
,
AllowedGroups
:
[]
int64
{
10
,
20
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
5
*
1024
*
1024
}
groupRepo
.
groups
[
20
]
=
&
Group
{
ID
:
20
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
20
*
1024
*
1024
),
quota
.
QuotaBytes
)
// 取最大值
require
.
Equal
(
t
,
"group"
,
quota
.
Source
)
}
func
TestGetQuota_SystemLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
SoraStorageUsedBytes
:
512
}
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"104857600"
,
// 100MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
104857600
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
"system"
,
quota
.
Source
)
}
func
TestGetQuota_NoLimit
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
SoraStorageUsedBytes
:
0
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
"unlimited"
,
quota
.
Source
)
}
func
TestGetQuota_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
_
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
999
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"get user"
)
}
func
TestGetQuota_GroupRepoError
(
t
*
testing
.
T
)
{
// 分组获取失败时跳过该分组(不影响整体)
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
999
},
// 不存在的分组
}
groupRepo
:=
newStubGroupRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"unlimited"
,
quota
.
Source
)
// 分组获取失败,回退到无限制
}
// ==================== CheckQuota ====================
func
TestCheckQuota_Sufficient
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1024
)
require
.
NoError
(
t
,
err
)
}
func
TestCheckQuota_Exceeded
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
10
*
1024
*
1024
,
// 已满
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"配额不足"
)
}
func
TestCheckQuota_NoLimit
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
// 无限制
SoraStorageUsedBytes
:
1000000000
,
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
999999999
)
require
.
NoError
(
t
,
err
)
// 无限制时始终通过
}
func
TestCheckQuota_ExactBoundary
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
1024
,
SoraStorageUsedBytes
:
1024
,
// 恰好满
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
// 额外 0 字节不超
require
.
NoError
(
t
,
svc
.
CheckQuota
(
context
.
Background
(),
1
,
0
))
// 额外 1 字节超出
require
.
Error
(
t
,
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1
))
}
// ==================== AddUsage ====================
func
TestAddUsage_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
2048
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3072
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestAddUsage_ZeroBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
0
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestAddUsage_NegativeBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
-
100
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestAddUsage_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
999
,
1024
)
require
.
Error
(
t
,
err
)
}
func
TestAddUsage_UpdateError
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
0
}
userRepo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
1024
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"update user quota usage"
)
}
// ==================== ReleaseUsage ====================
func
TestReleaseUsage_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
3072
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2048
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestReleaseUsage_ClampToZero
(
t
*
testing
.
T
)
{
// 释放量大于已用量时,应 clamp 到 0
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
500
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
1000
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestReleaseUsage_ZeroBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
0
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestReleaseUsage_NegativeBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
-
50
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestReleaseUsage_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
999
,
1024
)
require
.
Error
(
t
,
err
)
}
func
TestReleaseUsage_UpdateError
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
userRepo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
512
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"update user quota release"
)
}
// ==================== GetQuotaFromSettings ====================
func
TestGetQuotaFromSettings_NilSettingService
(
t
*
testing
.
T
)
{
svc
:=
NewSoraQuotaService
(
nil
,
nil
,
nil
)
require
.
Equal
(
t
,
int64
(
0
),
svc
.
GetQuotaFromSettings
(
context
.
Background
()))
}
func
TestGetQuotaFromSettings_WithSettings
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"52428800"
,
// 50MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
nil
,
nil
,
settingService
)
require
.
Equal
(
t
,
int64
(
52428800
),
svc
.
GetQuotaFromSettings
(
context
.
Background
()))
}
// ==================== SetUserSoraQuota ====================
func
TestSetUserSoraQuota_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
}
err
:=
SetUserSoraQuota
(
context
.
Background
(),
userRepo
,
1
,
10
*
1024
*
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
*
1024
*
1024
),
userRepo
.
users
[
1
]
.
SoraStorageQuotaBytes
)
}
func
TestSetUserSoraQuota_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
err
:=
SetUserSoraQuota
(
context
.
Background
(),
userRepo
,
999
,
1024
)
require
.
Error
(
t
,
err
)
}
// ==================== ParseQuotaBytes ====================
func
TestParseQuotaBytes
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
int64
(
1048576
),
ParseQuotaBytes
(
"1048576"
))
require
.
Equal
(
t
,
int64
(
0
),
ParseQuotaBytes
(
""
))
require
.
Equal
(
t
,
int64
(
0
),
ParseQuotaBytes
(
"abc"
))
require
.
Equal
(
t
,
int64
(
-
1
),
ParseQuotaBytes
(
"-1"
))
}
// ==================== 优先级完整测试 ====================
func
TestQuotaPriority_UserOverridesGroup
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
5
*
1024
*
1024
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"user"
,
quota
.
Source
)
// 用户级优先
require
.
Equal
(
t
,
int64
(
5
*
1024
*
1024
),
quota
.
QuotaBytes
)
}
func
TestQuotaPriority_GroupOverridesSystem
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"104857600"
,
// 100MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"group"
,
quota
.
Source
)
// 分组级优先于系统
require
.
Equal
(
t
,
int64
(
20
*
1024
*
1024
),
quota
.
QuotaBytes
)
}
func
TestQuotaPriority_FallbackToSystem
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
0
}
// 分组无配额
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"52428800"
,
// 50MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"system"
,
quota
.
Source
)
require
.
Equal
(
t
,
int64
(
52428800
),
quota
.
QuotaBytes
)
}
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment