Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
62e80c60
Commit
62e80c60
authored
Apr 05, 2026
by
erio
Browse files
revert: completely remove all Sora functionality
parent
dbb248df
Changes
136
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/openai_oauth_service_auth_url_test.go
View file @
62e80c60
...
...
@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
func
TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient
(
t
*
testing
.
T
)
{
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientAuthURLStub
{})
defer
svc
.
Stop
()
result
,
err
:=
svc
.
GenerateAuthURL
(
context
.
Background
(),
nil
,
""
,
PlatformSora
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
result
.
AuthURL
)
require
.
NotEmpty
(
t
,
result
.
SessionID
)
parsed
,
err
:=
url
.
Parse
(
result
.
AuthURL
)
require
.
NoError
(
t
,
err
)
q
:=
parsed
.
Query
()
require
.
Equal
(
t
,
openai
.
ClientID
,
q
.
Get
(
"client_id"
))
require
.
Empty
(
t
,
q
.
Get
(
"codex_cli_simplified_flow"
))
session
,
ok
:=
svc
.
sessionStore
.
Get
(
result
.
SessionID
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
backend/internal/service/openai_oauth_service_sora_session_test.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type
openaiOAuthClientNoopStub
struct
{}
func
(
s
*
openaiOAuthClientNoopStub
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_Success
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-token"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
info
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
require
.
Equal
(
t
,
"demo@example.com"
,
info
.
Email
)
require
.
Greater
(
t
,
info
.
ExpiresAt
,
int64
(
0
))
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"expires":"2099-01-01T00:00:00Z"}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
_
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"missing access token"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-cookie-value"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
"__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=chunk-0chunk-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=new-0new-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=ok-0ok-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"set-cookie"
,
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
backend/internal/service/openai_token_provider.go
View file @
62e80c60
...
...
@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider manages access_token for OpenAI
/Sora
OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
...
...
@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
(
account
.
Platform
!=
PlatformOpenAI
&&
account
.
Platform
!=
PlatformSora
)
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai
/sora
oauth account"
)
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
...
...
@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p
.
metrics
.
refreshRequests
.
Add
(
1
)
p
.
metrics
.
touchNow
()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if
account
.
Platform
==
PlatformSora
{
slog
.
Debug
(
"openai_token_refresh_skipped_for_sora"
,
"account_id"
,
account
.
ID
)
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
...
...
backend/internal/service/openai_token_provider_test.go
View file @
62e80c60
...
...
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
backend/internal/service/setting_service.go
View file @
62e80c60
...
...
@@ -22,8 +22,6 @@ import (
var
(
ErrRegistrationDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrSettingNotFound
=
infraerrors
.
NotFound
(
"SETTING_NOT_FOUND"
,
"setting not found"
)
ErrSoraS3ProfileNotFound
=
infraerrors
.
NotFound
(
"SORA_S3_PROFILE_NOT_FOUND"
,
"sora s3 profile not found"
)
ErrSoraS3ProfileExists
=
infraerrors
.
Conflict
(
"SORA_S3_PROFILE_EXISTS"
,
"sora s3 profile already exists"
)
ErrDefaultSubGroupInvalid
=
infraerrors
.
BadRequest
(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID"
,
"default subscription group must exist and be subscription type"
,
...
...
@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader
DefaultSubscriptionGroupReader
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
onS3Update
func
()
// Callback when Sora S3 settings are updated
version
string
// Application version
}
...
...
@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeySoraClientEnabled
,
SettingKeyCustomMenuItems
,
SettingKeyCustomEndpoints
,
SettingKeyLinuxDoConnectEnabled
,
...
...
@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
...
...
@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s
.
onUpdate
=
callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func
(
s
*
SettingService
)
SetOnS3UpdateCallback
(
callback
func
())
{
s
.
onS3Update
=
callback
}
// SetVersion sets the application version for injection into public settings
func
(
s
*
SettingService
)
SetVersion
(
version
string
)
{
s
.
version
=
version
...
...
@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
CustomMenuItems
json
.
RawMessage
`json:"custom_menu_items"`
CustomEndpoints
json
.
RawMessage
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
...
...
@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
SoraClientEnabled
:
settings
.
SoraClientEnabled
,
CustomMenuItems
:
filterUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
safeRawJSONArray
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
...
...
@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
updates
[
SettingKeySoraClientEnabled
]
=
strconv
.
FormatBool
(
settings
.
SoraClientEnabled
)
updates
[
SettingKeyCustomMenuItems
]
=
settings
.
CustomMenuItems
updates
[
SettingKeyCustomEndpoints
]
=
settings
.
CustomEndpoints
...
...
@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeySoraClientEnabled
:
"false"
,
SettingKeyCustomMenuItems
:
"[]"
,
SettingKeyCustomEndpoints
:
"[]"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
...
...
@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
BackendModeEnabled
:
settings
[
SettingKeyBackendModeEnabled
]
==
"true"
,
...
...
@@ -1584,606 +1569,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyStreamTimeoutSettings
,
string
(
data
))
}
type
soraS3ProfilesStore
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
soraS3ProfileStoreItem
`json:"items"`
}
type
soraS3ProfileStoreItem
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func
(
s
*
SettingService
)
GetSoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
profiles
,
err
:=
s
.
ListSoraS3Profiles
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
activeProfile
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
activeProfile
==
nil
{
return
&
SoraS3Settings
{},
nil
}
return
&
SoraS3Settings
{
Enabled
:
activeProfile
.
Enabled
,
Endpoint
:
activeProfile
.
Endpoint
,
Region
:
activeProfile
.
Region
,
Bucket
:
activeProfile
.
Bucket
,
AccessKeyID
:
activeProfile
.
AccessKeyID
,
SecretAccessKey
:
activeProfile
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
activeProfile
.
SecretAccessKeyConfigured
,
Prefix
:
activeProfile
.
Prefix
,
ForcePathStyle
:
activeProfile
.
ForcePathStyle
,
CDNURL
:
activeProfile
.
CDNURL
,
DefaultStorageQuotaBytes
:
activeProfile
.
DefaultStorageQuotaBytes
,
},
nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func
(
s
*
SettingService
)
SetSoraS3Settings
(
ctx
context
.
Context
,
settings
*
SoraS3Settings
)
error
{
if
settings
==
nil
{
return
fmt
.
Errorf
(
"settings cannot be nil"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
activeIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
store
.
ActiveProfileID
)
if
activeIndex
<
0
{
activeID
:=
"default"
if
hasSoraS3ProfileID
(
store
.
Items
,
activeID
)
{
activeID
=
fmt
.
Sprintf
(
"default-%d"
,
time
.
Now
()
.
Unix
())
}
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
activeID
,
Name
:
"Default"
,
UpdatedAt
:
now
,
})
store
.
ActiveProfileID
=
activeID
activeIndex
=
len
(
store
.
Items
)
-
1
}
active
:=
store
.
Items
[
activeIndex
]
active
.
Enabled
=
settings
.
Enabled
active
.
Endpoint
=
strings
.
TrimSpace
(
settings
.
Endpoint
)
active
.
Region
=
strings
.
TrimSpace
(
settings
.
Region
)
active
.
Bucket
=
strings
.
TrimSpace
(
settings
.
Bucket
)
active
.
AccessKeyID
=
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
active
.
Prefix
=
strings
.
TrimSpace
(
settings
.
Prefix
)
active
.
ForcePathStyle
=
settings
.
ForcePathStyle
active
.
CDNURL
=
strings
.
TrimSpace
(
settings
.
CDNURL
)
active
.
DefaultStorageQuotaBytes
=
maxInt64
(
settings
.
DefaultStorageQuotaBytes
,
0
)
if
settings
.
SecretAccessKey
!=
""
{
active
.
SecretAccessKey
=
settings
.
SecretAccessKey
}
active
.
UpdatedAt
=
now
store
.
Items
[
activeIndex
]
=
active
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func
(
s
*
SettingService
)
ListSoraS3Profiles
(
ctx
context
.
Context
)
(
*
SoraS3ProfileList
,
error
)
{
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
convertSoraS3ProfilesStore
(
store
),
nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func
(
s
*
SettingService
)
CreateSoraS3Profile
(
ctx
context
.
Context
,
profile
*
SoraS3Profile
,
setActive
bool
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
profileID
:=
strings
.
TrimSpace
(
profile
.
ProfileID
)
if
profileID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
if
hasSoraS3ProfileID
(
store
.
Items
,
profileID
)
{
return
nil
,
ErrSoraS3ProfileExists
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
profileID
,
Name
:
name
,
Enabled
:
profile
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
profile
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
profile
.
Region
),
Bucket
:
strings
.
TrimSpace
(
profile
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
profile
.
AccessKeyID
),
SecretAccessKey
:
profile
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
profile
.
Prefix
),
ForcePathStyle
:
profile
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
profile
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
})
if
setActive
||
store
.
ActiveProfileID
==
""
{
store
.
ActiveProfileID
=
profileID
}
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
created
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
profileID
)
if
created
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
created
,
nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func
(
s
*
SettingService
)
UpdateSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
,
profile
*
SoraS3Profile
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
target
:=
store
.
Items
[
targetIndex
]
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
target
.
Name
=
name
target
.
Enabled
=
profile
.
Enabled
target
.
Endpoint
=
strings
.
TrimSpace
(
profile
.
Endpoint
)
target
.
Region
=
strings
.
TrimSpace
(
profile
.
Region
)
target
.
Bucket
=
strings
.
TrimSpace
(
profile
.
Bucket
)
target
.
AccessKeyID
=
strings
.
TrimSpace
(
profile
.
AccessKeyID
)
target
.
Prefix
=
strings
.
TrimSpace
(
profile
.
Prefix
)
target
.
ForcePathStyle
=
profile
.
ForcePathStyle
target
.
CDNURL
=
strings
.
TrimSpace
(
profile
.
CDNURL
)
target
.
DefaultStorageQuotaBytes
=
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
)
if
profile
.
SecretAccessKey
!=
""
{
target
.
SecretAccessKey
=
profile
.
SecretAccessKey
}
target
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
[
targetIndex
]
=
target
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
updated
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
targetID
)
if
updated
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
updated
,
nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func
(
s
*
SettingService
)
DeleteSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
ErrSoraS3ProfileNotFound
}
store
.
Items
=
append
(
store
.
Items
[
:
targetIndex
],
store
.
Items
[
targetIndex
+
1
:
]
...
)
if
store
.
ActiveProfileID
==
targetID
{
store
.
ActiveProfileID
=
""
if
len
(
store
.
Items
)
>
0
{
store
.
ActiveProfileID
=
store
.
Items
[
0
]
.
ProfileID
}
}
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func
(
s
*
SettingService
)
SetActiveSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
*
SoraS3Profile
,
error
)
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
store
.
ActiveProfileID
=
targetID
store
.
Items
[
targetIndex
]
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
active
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
active
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
active
,
nil
}
func
(
s
*
SettingService
)
loadSoraS3ProfilesStore
(
ctx
context
.
Context
)
(
*
soraS3ProfilesStore
,
error
)
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySoraS3Profiles
)
if
err
==
nil
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
&
soraS3ProfilesStore
{},
nil
}
var
store
soraS3ProfilesStore
if
unmarshalErr
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
store
);
unmarshalErr
!=
nil
{
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal sora s3 profiles: %w"
,
unmarshalErr
)
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
normalized
:=
normalizeSoraS3ProfilesStore
(
store
)
return
&
normalized
,
nil
}
if
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"get sora s3 profiles: %w"
,
err
)
}
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
legacyErr
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
func
(
s
*
SettingService
)
persistSoraS3ProfilesStore
(
ctx
context
.
Context
,
store
*
soraS3ProfilesStore
)
error
{
if
store
==
nil
{
return
fmt
.
Errorf
(
"sora s3 profiles store cannot be nil"
)
}
normalized
:=
normalizeSoraS3ProfilesStore
(
*
store
)
data
,
err
:=
json
.
Marshal
(
normalized
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal sora s3 profiles: %w"
,
err
)
}
updates
:=
map
[
string
]
string
{
SettingKeySoraS3Profiles
:
string
(
data
),
}
active
:=
pickActiveSoraS3ProfileFromStore
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
if
active
==
nil
{
updates
[
SettingKeySoraS3Enabled
]
=
"false"
updates
[
SettingKeySoraS3Endpoint
]
=
""
updates
[
SettingKeySoraS3Region
]
=
""
updates
[
SettingKeySoraS3Bucket
]
=
""
updates
[
SettingKeySoraS3AccessKeyID
]
=
""
updates
[
SettingKeySoraS3Prefix
]
=
""
updates
[
SettingKeySoraS3ForcePathStyle
]
=
"false"
updates
[
SettingKeySoraS3CDNURL
]
=
""
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
"0"
updates
[
SettingKeySoraS3SecretAccessKey
]
=
""
}
else
{
updates
[
SettingKeySoraS3Enabled
]
=
strconv
.
FormatBool
(
active
.
Enabled
)
updates
[
SettingKeySoraS3Endpoint
]
=
strings
.
TrimSpace
(
active
.
Endpoint
)
updates
[
SettingKeySoraS3Region
]
=
strings
.
TrimSpace
(
active
.
Region
)
updates
[
SettingKeySoraS3Bucket
]
=
strings
.
TrimSpace
(
active
.
Bucket
)
updates
[
SettingKeySoraS3AccessKeyID
]
=
strings
.
TrimSpace
(
active
.
AccessKeyID
)
updates
[
SettingKeySoraS3Prefix
]
=
strings
.
TrimSpace
(
active
.
Prefix
)
updates
[
SettingKeySoraS3ForcePathStyle
]
=
strconv
.
FormatBool
(
active
.
ForcePathStyle
)
updates
[
SettingKeySoraS3CDNURL
]
=
strings
.
TrimSpace
(
active
.
CDNURL
)
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
strconv
.
FormatInt
(
maxInt64
(
active
.
DefaultStorageQuotaBytes
,
0
),
10
)
updates
[
SettingKeySoraS3SecretAccessKey
]
=
active
.
SecretAccessKey
}
if
err
:=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
);
err
!=
nil
{
return
err
}
if
s
.
onUpdate
!=
nil
{
s
.
onUpdate
()
}
if
s
.
onS3Update
!=
nil
{
s
.
onS3Update
()
}
return
nil
}
func
(
s
*
SettingService
)
getLegacySoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
keys
:=
[]
string
{
SettingKeySoraS3Enabled
,
SettingKeySoraS3Endpoint
,
SettingKeySoraS3Region
,
SettingKeySoraS3Bucket
,
SettingKeySoraS3AccessKeyID
,
SettingKeySoraS3SecretAccessKey
,
SettingKeySoraS3Prefix
,
SettingKeySoraS3ForcePathStyle
,
SettingKeySoraS3CDNURL
,
SettingKeySoraDefaultStorageQuotaBytes
,
}
values
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get legacy sora s3 settings: %w"
,
err
)
}
result
:=
&
SoraS3Settings
{
Enabled
:
values
[
SettingKeySoraS3Enabled
]
==
"true"
,
Endpoint
:
values
[
SettingKeySoraS3Endpoint
],
Region
:
values
[
SettingKeySoraS3Region
],
Bucket
:
values
[
SettingKeySoraS3Bucket
],
AccessKeyID
:
values
[
SettingKeySoraS3AccessKeyID
],
SecretAccessKey
:
values
[
SettingKeySoraS3SecretAccessKey
],
SecretAccessKeyConfigured
:
values
[
SettingKeySoraS3SecretAccessKey
]
!=
""
,
Prefix
:
values
[
SettingKeySoraS3Prefix
],
ForcePathStyle
:
values
[
SettingKeySoraS3ForcePathStyle
]
==
"true"
,
CDNURL
:
values
[
SettingKeySoraS3CDNURL
],
}
if
v
,
parseErr
:=
strconv
.
ParseInt
(
values
[
SettingKeySoraDefaultStorageQuotaBytes
],
10
,
64
);
parseErr
==
nil
{
result
.
DefaultStorageQuotaBytes
=
v
}
return
result
,
nil
}
func
normalizeSoraS3ProfilesStore
(
store
soraS3ProfilesStore
)
soraS3ProfilesStore
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
store
.
Items
))
normalized
:=
soraS3ProfilesStore
{
ActiveProfileID
:
strings
.
TrimSpace
(
store
.
ActiveProfileID
),
Items
:
make
([]
soraS3ProfileStoreItem
,
0
,
len
(
store
.
Items
)),
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
item
.
ProfileID
=
strings
.
TrimSpace
(
item
.
ProfileID
)
if
item
.
ProfileID
==
""
{
item
.
ProfileID
=
fmt
.
Sprintf
(
"profile-%d"
,
idx
+
1
)
}
if
_
,
exists
:=
seen
[
item
.
ProfileID
];
exists
{
continue
}
seen
[
item
.
ProfileID
]
=
struct
{}{}
item
.
Name
=
strings
.
TrimSpace
(
item
.
Name
)
if
item
.
Name
==
""
{
item
.
Name
=
item
.
ProfileID
}
item
.
Endpoint
=
strings
.
TrimSpace
(
item
.
Endpoint
)
item
.
Region
=
strings
.
TrimSpace
(
item
.
Region
)
item
.
Bucket
=
strings
.
TrimSpace
(
item
.
Bucket
)
item
.
AccessKeyID
=
strings
.
TrimSpace
(
item
.
AccessKeyID
)
item
.
Prefix
=
strings
.
TrimSpace
(
item
.
Prefix
)
item
.
CDNURL
=
strings
.
TrimSpace
(
item
.
CDNURL
)
item
.
DefaultStorageQuotaBytes
=
maxInt64
(
item
.
DefaultStorageQuotaBytes
,
0
)
item
.
UpdatedAt
=
strings
.
TrimSpace
(
item
.
UpdatedAt
)
if
item
.
UpdatedAt
==
""
{
item
.
UpdatedAt
=
now
}
normalized
.
Items
=
append
(
normalized
.
Items
,
item
)
}
if
len
(
normalized
.
Items
)
==
0
{
normalized
.
ActiveProfileID
=
""
return
normalized
}
if
findSoraS3ProfileIndex
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
>=
0
{
return
normalized
}
normalized
.
ActiveProfileID
=
normalized
.
Items
[
0
]
.
ProfileID
return
normalized
}
func
convertSoraS3ProfilesStore
(
store
*
soraS3ProfilesStore
)
*
SoraS3ProfileList
{
if
store
==
nil
{
return
&
SoraS3ProfileList
{}
}
items
:=
make
([]
SoraS3Profile
,
0
,
len
(
store
.
Items
))
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
items
=
append
(
items
,
SoraS3Profile
{
ProfileID
:
item
.
ProfileID
,
Name
:
item
.
Name
,
IsActive
:
item
.
ProfileID
==
store
.
ActiveProfileID
,
Enabled
:
item
.
Enabled
,
Endpoint
:
item
.
Endpoint
,
Region
:
item
.
Region
,
Bucket
:
item
.
Bucket
,
AccessKeyID
:
item
.
AccessKeyID
,
SecretAccessKey
:
item
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
item
.
SecretAccessKey
!=
""
,
Prefix
:
item
.
Prefix
,
ForcePathStyle
:
item
.
ForcePathStyle
,
CDNURL
:
item
.
CDNURL
,
DefaultStorageQuotaBytes
:
item
.
DefaultStorageQuotaBytes
,
UpdatedAt
:
item
.
UpdatedAt
,
})
}
return
&
SoraS3ProfileList
{
ActiveProfileID
:
store
.
ActiveProfileID
,
Items
:
items
,
}
}
func
pickActiveSoraS3Profile
(
items
[]
SoraS3Profile
,
activeProfileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileByID
(
items
[]
SoraS3Profile
,
profileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
&
items
[
idx
]
}
}
return
nil
}
func
pickActiveSoraS3ProfileFromStore
(
items
[]
soraS3ProfileStoreItem
,
activeProfileID
string
)
*
soraS3ProfileStoreItem
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileIndex
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
int
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
idx
}
}
return
-
1
}
func
hasSoraS3ProfileID
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
bool
{
return
findSoraS3ProfileIndex
(
items
,
profileID
)
>=
0
}
func
isEmptyLegacySoraS3Settings
(
settings
*
SoraS3Settings
)
bool
{
if
settings
==
nil
{
return
true
}
if
settings
.
Enabled
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Endpoint
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Region
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Bucket
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
!=
""
{
return
false
}
if
settings
.
SecretAccessKey
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Prefix
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
CDNURL
)
!=
""
{
return
false
}
return
settings
.
DefaultStorageQuotaBytes
==
0
}
func
maxInt64
(
value
int64
,
min
int64
)
int64
{
if
value
<
min
{
return
min
}
return
value
}
backend/internal/service/settings_view.go
View file @
62e80c60
...
...
@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -116,46 +114,6 @@ type PublicSettings struct {
Version
string
}
// SoraS3Settings Sora S3 存储配置
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"-"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type
SoraS3ProfileList
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type
StreamTimeoutSettings
struct
{
// Enabled 是否启用流超时处理
...
...
backend/internal/service/sora_account_service.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
"context"
// SoraAccountRepository Sora 账号扩展表仓储接口
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
//
// 设计说明:
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
// - Token 刷新时需要同时更新两个表以保持数据一致性
type
SoraAccountRepository
interface
{
// Upsert 创建或更新 Sora 账号扩展信息
// accountID: 关联的 accounts.id
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
//
// 如果记录不存在则创建,存在则更新。
// 用于:
// 1. 创建 Sora 账号时初始化扩展表
// 2. Token 刷新时同步更新扩展表
Upsert
(
ctx
context
.
Context
,
accountID
int64
,
updates
map
[
string
]
any
)
error
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
// 返回 nil, nil 表示记录不存在(非错误)
GetByAccountID
(
ctx
context
.
Context
,
accountID
int64
)
(
*
SoraAccount
,
error
)
// Delete 删除 Sora 账号扩展信息
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
Delete
(
ctx
context
.
Context
,
accountID
int64
)
error
}
// SoraAccount Sora 账号扩展信息
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
type
SoraAccount
struct
{
AccountID
int64
// 关联的 accounts.id
AccessToken
string
// OAuth access_token
RefreshToken
string
// OAuth refresh_token
SessionToken
string
// Session token(可选,用于 ST→AT 兜底)
}
backend/internal/service/sora_client.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"context"
"fmt"
"net/http"
)
// SoraClient 定义直连 Sora 的任务操作接口。
type
SoraClient
interface
{
Enabled
()
bool
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
}
// SoraImageRequest 图片生成请求参数
type
SoraImageRequest
struct
{
Prompt
string
Width
int
Height
int
MediaID
string
}
// SoraVideoRequest 视频生成请求参数
type
SoraVideoRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
VideoCount
int
MediaID
string
RemixTargetID
string
CameoIDs
[]
string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type
SoraStoryboardRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
MediaID
string
}
// SoraImageTaskStatus 图片任务状态
type
SoraImageTaskStatus
struct
{
ID
string
Status
string
ProgressPct
float64
URLs
[]
string
ErrorMsg
string
}
// SoraVideoTaskStatus 视频任务状态
type
SoraVideoTaskStatus
struct
{
ID
string
Status
string
ProgressPct
int
URLs
[]
string
GenerationID
string
ErrorMsg
string
}
// SoraCameoStatus 角色处理中间态
type
SoraCameoStatus
struct
{
Status
string
StatusMessage
string
DisplayNameHint
string
UsernameHint
string
ProfileAssetURL
string
InstructionSetHint
any
InstructionSet
any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type
SoraCharacterFinalizeRequest
struct
{
CameoID
string
Username
string
DisplayName
string
ProfileAssetPointer
string
InstructionSet
any
}
// SoraUpstreamError 上游错误
type
SoraUpstreamError
struct
{
StatusCode
int
Message
string
Headers
http
.
Header
Body
[]
byte
}
func
(
e
*
SoraUpstreamError
)
Error
()
string
{
if
e
==
nil
{
return
"sora upstream error"
}
if
e
.
Message
!=
""
{
return
fmt
.
Sprintf
(
"sora upstream error: %d %s"
,
e
.
StatusCode
,
e
.
Message
)
}
return
fmt
.
Sprintf
(
"sora upstream error: %d"
,
e
.
StatusCode
)
}
backend/internal/service/sora_gateway_service.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"math/rand"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
const
soraImageInputMaxBytes
=
20
<<
20
const
soraImageInputMaxRedirects
=
3
const
soraImageInputTimeout
=
20
*
time
.
Second
const
soraVideoInputMaxBytes
=
200
<<
20
const
soraVideoInputMaxRedirects
=
3
const
soraVideoInputTimeout
=
60
*
time
.
Second
var
soraImageSizeMap
=
map
[
string
]
string
{
"gpt-image"
:
"360"
,
"gpt-image-landscape"
:
"540"
,
"gpt-image-portrait"
:
"540"
,
}
var
soraBlockedHostnames
=
map
[
string
]
struct
{}{
"localhost"
:
{},
"localhost.localdomain"
:
{},
"metadata.google.internal"
:
{},
"metadata.google.internal."
:
{},
}
var
soraBlockedCIDRs
=
mustParseCIDRs
([]
string
{
"0.0.0.0/8"
,
"10.0.0.0/8"
,
"100.64.0.0/10"
,
"127.0.0.0/8"
,
"169.254.0.0/16"
,
"172.16.0.0/12"
,
"192.168.0.0/16"
,
"224.0.0.0/4"
,
"240.0.0.0/4"
,
"::/128"
,
"::1/128"
,
"fc00::/7"
,
"fe80::/10"
,
})
// SoraGatewayService handles forwarding requests to Sora upstream.
type
SoraGatewayService
struct
{
soraClient
SoraClient
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
// 用于 apikey 类型账号的 HTTP 透传
cfg
*
config
.
Config
}
type
soraWatermarkOptions
struct
{
Enabled
bool
ParseMethod
string
ParseURL
string
ParseToken
string
FallbackOnFailure
bool
DeletePost
bool
}
type
soraCharacterOptions
struct
{
SetPublic
bool
DeleteAfterGenerate
bool
}
type
soraCharacterFlowResult
struct
{
CameoID
string
CharacterID
string
Username
string
DisplayName
string
}
var
soraStoryboardPattern
=
regexp
.
MustCompile
(
`\[\d+(?:\.\d+)?s\]`
)
var
soraStoryboardShotPattern
=
regexp
.
MustCompile
(
`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`
)
var
soraRemixTargetPattern
=
regexp
.
MustCompile
(
`s_[a-f0-9]{32}`
)
var
soraRemixTargetInURLPattern
=
regexp
.
MustCompile
(
`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`
)
type
soraPreflightChecker
interface
{
PreflightCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
,
modelCfg
SoraModelConfig
)
error
}
func
NewSoraGatewayService
(
soraClient
SoraClient
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
)
*
SoraGatewayService
{
return
&
SoraGatewayService
{
soraClient
:
soraClient
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
clientStream
bool
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if
account
.
Type
==
AccountTypeAPIKey
&&
account
.
GetBaseURL
()
!=
""
{
if
s
.
httpUpstream
==
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"HTTP upstream client not configured"
,
clientStream
)
return
nil
,
errors
.
New
(
"httpUpstream not configured for sora apikey forwarding"
)
}
return
s
.
forwardToUpstream
(
ctx
,
c
,
account
,
body
,
clientStream
,
startTime
)
}
if
s
.
soraClient
==
nil
||
!
s
.
soraClient
.
Enabled
()
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 上游未配置"
,
},
})
}
return
nil
,
errors
.
New
(
"sora upstream not configured"
)
}
var
reqBody
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
reqBody
);
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"parse request: %w"
,
err
)
}
reqModel
,
_
:=
reqBody
[
"model"
]
.
(
string
)
reqStream
,
_
:=
reqBody
[
"stream"
]
.
(
bool
)
if
strings
.
TrimSpace
(
reqModel
)
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"model is required"
)
}
originalModel
:=
reqModel
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
var
upstreamModel
string
if
mappedModel
!=
""
&&
mappedModel
!=
reqModel
{
reqModel
=
mappedModel
upstreamModel
=
mappedModel
}
modelCfg
,
ok
:=
GetSoraModelConfig
(
reqModel
)
if
!
ok
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Unsupported Sora model"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"unsupported model: %s"
,
reqModel
)
}
prompt
,
imageInput
,
videoInput
,
remixTargetID
:=
extractSoraInput
(
reqBody
)
prompt
=
strings
.
TrimSpace
(
prompt
)
imageInput
=
strings
.
TrimSpace
(
imageInput
)
videoInput
=
strings
.
TrimSpace
(
videoInput
)
remixTargetID
=
strings
.
TrimSpace
(
remixTargetID
)
if
videoInput
!=
""
&&
modelCfg
.
Type
!=
"video"
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"video input only supports video models"
,
clientStream
)
return
nil
,
errors
.
New
(
"video input only supports video models"
)
}
if
videoInput
!=
""
&&
imageInput
!=
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"image input and video input cannot be used together"
,
clientStream
)
return
nil
,
errors
.
New
(
"image input and video input cannot be used together"
)
}
characterOnly
:=
videoInput
!=
""
&&
prompt
==
""
if
modelCfg
.
Type
==
"prompt_enhance"
&&
prompt
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
if
modelCfg
.
Type
!=
"prompt_enhance"
&&
prompt
==
""
&&
!
characterOnly
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
reqCtx
,
cancel
:=
s
.
withSoraTimeout
(
ctx
,
reqStream
)
if
cancel
!=
nil
{
defer
cancel
()
}
if
checker
,
ok
:=
s
.
soraClient
.
(
soraPreflightChecker
);
ok
&&
!
characterOnly
{
if
err
:=
checker
.
PreflightCheck
(
reqCtx
,
account
,
reqModel
,
modelCfg
);
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
}
if
modelCfg
.
Type
==
"prompt_enhance"
{
enhancedPrompt
,
err
:=
s
.
soraClient
.
EnhancePrompt
(
reqCtx
,
account
,
prompt
,
modelCfg
.
ExpansionLevel
,
modelCfg
.
DurationS
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
content
:=
strings
.
TrimSpace
(
enhancedPrompt
)
if
content
==
""
{
content
=
prompt
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusOK
,
buildSoraNonStreamResponse
(
content
,
reqModel
))
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
characterOpts
:=
parseSoraCharacterOptions
(
reqBody
)
watermarkOpts
:=
parseSoraWatermarkOptions
(
reqBody
)
var
characterResult
*
soraCharacterFlowResult
if
videoInput
!=
""
{
videoData
,
videoErr
:=
decodeSoraVideoInput
(
reqCtx
,
videoInput
)
if
videoErr
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
videoErr
.
Error
(),
clientStream
)
return
nil
,
videoErr
}
characterResult
,
videoErr
=
s
.
createCharacterFromVideo
(
reqCtx
,
account
,
videoData
,
characterOpts
)
if
videoErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
videoErr
,
reqModel
,
c
,
clientStream
)
}
if
characterResult
!=
nil
&&
characterOpts
.
DeleteAfterGenerate
&&
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
!=
""
&&
!
characterOnly
{
characterID
:=
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
defer
func
()
{
cleanupCtx
,
cancelCleanup
:=
context
.
WithTimeout
(
context
.
Background
(),
15
*
time
.
Second
)
defer
cancelCleanup
()
if
err
:=
s
.
soraClient
.
DeleteCharacter
(
cleanupCtx
,
account
,
characterID
);
err
!=
nil
{
log
.
Printf
(
"[Sora] cleanup character failed, character_id=%s err=%v"
,
characterID
,
err
)
}
}()
}
if
characterOnly
{
content
:=
"角色创建成功"
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
content
=
fmt
.
Sprintf
(
"角色创建成功,角色名@%s"
,
strings
.
TrimSpace
(
characterResult
.
Username
))
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
resp
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
characterResult
!=
nil
{
resp
[
"character_id"
]
=
characterResult
.
CharacterID
resp
[
"cameo_id"
]
=
characterResult
.
CameoID
resp
[
"character_username"
]
=
characterResult
.
Username
resp
[
"character_display_name"
]
=
characterResult
.
DisplayName
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
prompt
=
fmt
.
Sprintf
(
"@%s %s"
,
characterResult
.
Username
,
prompt
)
}
}
var
imageData
[]
byte
imageFilename
:=
""
if
imageInput
!=
""
{
decoded
,
filename
,
err
:=
decodeSoraImageInput
(
reqCtx
,
imageInput
)
if
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
(),
clientStream
)
return
nil
,
err
}
imageData
=
decoded
imageFilename
=
filename
}
mediaID
:=
""
if
len
(
imageData
)
>
0
{
uploadID
,
err
:=
s
.
soraClient
.
UploadImage
(
reqCtx
,
account
,
imageData
,
imageFilename
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
mediaID
=
uploadID
}
taskID
:=
""
var
err
error
videoCount
:=
parseSoraVideoCount
(
reqBody
)
switch
modelCfg
.
Type
{
case
"image"
:
taskID
,
err
=
s
.
soraClient
.
CreateImageTask
(
reqCtx
,
account
,
SoraImageRequest
{
Prompt
:
prompt
,
Width
:
modelCfg
.
Width
,
Height
:
modelCfg
.
Height
,
MediaID
:
mediaID
,
})
case
"video"
:
if
remixTargetID
==
""
&&
isSoraStoryboardPrompt
(
prompt
)
{
taskID
,
err
=
s
.
soraClient
.
CreateStoryboardTask
(
reqCtx
,
account
,
SoraStoryboardRequest
{
Prompt
:
formatSoraStoryboardPrompt
(
prompt
),
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
MediaID
:
mediaID
,
})
}
else
{
taskID
,
err
=
s
.
soraClient
.
CreateVideoTask
(
reqCtx
,
account
,
SoraVideoRequest
{
Prompt
:
prompt
,
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
VideoCount
:
videoCount
,
MediaID
:
mediaID
,
RemixTargetID
:
remixTargetID
,
CameoIDs
:
extractSoraCameoIDs
(
reqBody
),
})
}
default
:
err
=
fmt
.
Errorf
(
"unsupported model type: %s"
,
modelCfg
.
Type
)
}
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
if
clientStream
&&
c
!=
nil
{
s
.
prepareSoraStream
(
c
,
taskID
)
}
var
mediaURLs
[]
string
videoGenerationID
:=
""
mediaType
:=
modelCfg
.
Type
imageCount
:=
0
imageSize
:=
""
switch
modelCfg
.
Type
{
case
"image"
:
urls
,
pollErr
:=
s
.
pollImageTask
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
mediaURLs
=
urls
imageCount
=
len
(
urls
)
imageSize
=
soraImageSizeFromModel
(
reqModel
)
case
"video"
:
videoStatus
,
pollErr
:=
s
.
pollVideoTaskDetailed
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
if
videoStatus
!=
nil
{
mediaURLs
=
videoStatus
.
URLs
videoGenerationID
=
strings
.
TrimSpace
(
videoStatus
.
GenerationID
)
}
default
:
mediaType
=
"prompt"
}
watermarkPostID
:=
""
if
modelCfg
.
Type
==
"video"
&&
watermarkOpts
.
Enabled
{
watermarkURL
,
postID
,
watermarkErr
:=
s
.
resolveWatermarkFreeURL
(
reqCtx
,
account
,
videoGenerationID
,
watermarkOpts
)
if
watermarkErr
!=
nil
{
if
!
watermarkOpts
.
FallbackOnFailure
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
watermarkErr
,
reqModel
,
c
,
clientStream
)
}
log
.
Printf
(
"[Sora] watermark-free fallback to original URL, task_id=%s err=%v"
,
taskID
,
watermarkErr
)
}
else
if
strings
.
TrimSpace
(
watermarkURL
)
!=
""
{
mediaURLs
=
[]
string
{
strings
.
TrimSpace
(
watermarkURL
)}
watermarkPostID
=
strings
.
TrimSpace
(
postID
)
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs
:=
s
.
normalizeSoraMediaURLs
(
mediaURLs
)
if
watermarkPostID
!=
""
&&
watermarkOpts
.
DeletePost
{
if
deleteErr
:=
s
.
soraClient
.
DeletePost
(
reqCtx
,
account
,
watermarkPostID
);
deleteErr
!=
nil
{
log
.
Printf
(
"[Sora] delete post failed, post_id=%s err=%v"
,
watermarkPostID
,
deleteErr
)
}
}
content
:=
buildSoraContent
(
mediaType
,
finalURLs
)
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
response
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
len
(
finalURLs
)
>
0
{
response
[
"media_url"
]
=
finalURLs
[
0
]
if
len
(
finalURLs
)
>
1
{
response
[
"media_urls"
]
=
finalURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
ForwardResult
{
RequestID
:
taskID
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
mediaType
,
MediaURL
:
firstMediaURL
(
finalURLs
),
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
func
(
s
*
SoraGatewayService
)
withSoraTimeout
(
ctx
context
.
Context
,
stream
bool
)
(
context
.
Context
,
context
.
CancelFunc
)
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
ctx
,
nil
}
timeoutSeconds
:=
s
.
cfg
.
Gateway
.
SoraRequestTimeoutSeconds
if
stream
{
timeoutSeconds
=
s
.
cfg
.
Gateway
.
SoraStreamTimeoutSeconds
}
if
timeoutSeconds
<=
0
{
return
ctx
,
nil
}
return
context
.
WithTimeout
(
ctx
,
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
)
}
func
parseSoraWatermarkOptions
(
body
map
[
string
]
any
)
soraWatermarkOptions
{
opts
:=
soraWatermarkOptions
{
Enabled
:
parseBoolWithDefault
(
body
,
"watermark_free"
,
false
),
ParseMethod
:
strings
.
ToLower
(
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_method"
,
"third_party"
))),
ParseURL
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_url"
,
""
)),
ParseToken
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_token"
,
""
)),
FallbackOnFailure
:
parseBoolWithDefault
(
body
,
"watermark_fallback_on_failure"
,
true
),
DeletePost
:
parseBoolWithDefault
(
body
,
"watermark_delete_post"
,
false
),
}
if
opts
.
ParseMethod
==
""
{
opts
.
ParseMethod
=
"third_party"
}
return
opts
}
func
parseSoraCharacterOptions
(
body
map
[
string
]
any
)
soraCharacterOptions
{
return
soraCharacterOptions
{
SetPublic
:
parseBoolWithDefault
(
body
,
"character_set_public"
,
true
),
DeleteAfterGenerate
:
parseBoolWithDefault
(
body
,
"character_delete_after_generate"
,
true
),
}
}
func
parseSoraVideoCount
(
body
map
[
string
]
any
)
int
{
if
body
==
nil
{
return
1
}
keys
:=
[]
string
{
"video_count"
,
"videos"
,
"n_variants"
}
for
_
,
key
:=
range
keys
{
count
:=
parseIntWithDefault
(
body
,
key
,
0
)
if
count
>
0
{
return
clampInt
(
count
,
1
,
3
)
}
}
return
1
}
func
parseBoolWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
bool
)
bool
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
bool
:
return
typed
case
int
:
return
typed
!=
0
case
int32
:
return
typed
!=
0
case
int64
:
return
typed
!=
0
case
float64
:
return
typed
!=
0
case
string
:
typed
=
strings
.
ToLower
(
strings
.
TrimSpace
(
typed
))
if
typed
==
"true"
||
typed
==
"1"
||
typed
==
"yes"
{
return
true
}
if
typed
==
"false"
||
typed
==
"0"
||
typed
==
"no"
{
return
false
}
}
return
def
}
func
parseStringWithDefault
(
body
map
[
string
]
any
,
key
,
def
string
)
string
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
if
str
,
ok
:=
val
.
(
string
);
ok
{
return
str
}
return
def
}
func
parseIntWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
int
)
int
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
int
:
return
typed
case
int32
:
return
int
(
typed
)
case
int64
:
return
int
(
typed
)
case
float64
:
return
int
(
typed
)
case
string
:
parsed
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
typed
))
if
err
==
nil
{
return
parsed
}
}
return
def
}
func
clampInt
(
v
,
minVal
,
maxVal
int
)
int
{
if
v
<
minVal
{
return
minVal
}
if
v
>
maxVal
{
return
maxVal
}
return
v
}
func
extractSoraCameoIDs
(
body
map
[
string
]
any
)
[]
string
{
if
body
==
nil
{
return
nil
}
raw
,
ok
:=
body
[
"cameo_ids"
]
if
!
ok
{
return
nil
}
switch
typed
:=
raw
.
(
type
)
{
case
[]
string
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
item
=
strings
.
TrimSpace
(
item
)
if
item
!=
""
{
out
=
append
(
out
,
item
)
}
}
return
out
case
[]
any
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
str
,
ok
:=
item
.
(
string
)
if
!
ok
{
continue
}
str
=
strings
.
TrimSpace
(
str
)
if
str
!=
""
{
out
=
append
(
out
,
str
)
}
}
return
out
default
:
return
nil
}
}
func
(
s
*
SoraGatewayService
)
createCharacterFromVideo
(
ctx
context
.
Context
,
account
*
Account
,
videoData
[]
byte
,
opts
soraCharacterOptions
)
(
*
soraCharacterFlowResult
,
error
)
{
cameoID
,
err
:=
s
.
soraClient
.
UploadCharacterVideo
(
ctx
,
account
,
videoData
)
if
err
!=
nil
{
return
nil
,
err
}
cameoStatus
,
err
:=
s
.
pollCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
return
nil
,
err
}
username
:=
processSoraCharacterUsername
(
cameoStatus
.
UsernameHint
)
displayName
:=
strings
.
TrimSpace
(
cameoStatus
.
DisplayNameHint
)
if
displayName
==
""
{
displayName
=
"Character"
}
profileAssetURL
:=
strings
.
TrimSpace
(
cameoStatus
.
ProfileAssetURL
)
if
profileAssetURL
==
""
{
return
nil
,
errors
.
New
(
"profile asset url not found in cameo status"
)
}
avatarData
,
err
:=
s
.
soraClient
.
DownloadCharacterImage
(
ctx
,
account
,
profileAssetURL
)
if
err
!=
nil
{
return
nil
,
err
}
assetPointer
,
err
:=
s
.
soraClient
.
UploadCharacterImage
(
ctx
,
account
,
avatarData
)
if
err
!=
nil
{
return
nil
,
err
}
instructionSet
:=
cameoStatus
.
InstructionSetHint
if
instructionSet
==
nil
{
instructionSet
=
cameoStatus
.
InstructionSet
}
characterID
,
err
:=
s
.
soraClient
.
FinalizeCharacter
(
ctx
,
account
,
SoraCharacterFinalizeRequest
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
Username
:
username
,
DisplayName
:
displayName
,
ProfileAssetPointer
:
assetPointer
,
InstructionSet
:
instructionSet
,
})
if
err
!=
nil
{
return
nil
,
err
}
if
opts
.
SetPublic
{
if
err
:=
s
.
soraClient
.
SetCharacterPublic
(
ctx
,
account
,
cameoID
);
err
!=
nil
{
return
nil
,
err
}
}
return
&
soraCharacterFlowResult
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
CharacterID
:
strings
.
TrimSpace
(
characterID
),
Username
:
strings
.
TrimSpace
(
username
),
DisplayName
:
displayName
,
},
nil
}
func
(
s
*
SoraGatewayService
)
pollCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
timeout
:=
10
*
time
.
Minute
interval
:=
5
*
time
.
Second
maxAttempts
:=
int
(
math
.
Ceil
(
timeout
.
Seconds
()
/
interval
.
Seconds
()))
if
maxAttempts
<
1
{
maxAttempts
=
1
}
var
lastErr
error
consecutiveErrors
:=
0
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
lastErr
=
err
consecutiveErrors
++
if
consecutiveErrors
>=
3
{
break
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
consecutiveErrors
=
0
if
status
==
nil
{
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
currentStatus
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
status
.
Status
))
statusMessage
:=
strings
.
TrimSpace
(
status
.
StatusMessage
)
if
currentStatus
==
"failed"
{
if
statusMessage
==
""
{
statusMessage
=
"character creation failed"
}
return
nil
,
errors
.
New
(
statusMessage
)
}
if
strings
.
EqualFold
(
statusMessage
,
"Completed"
)
||
currentStatus
==
"finalized"
{
return
status
,
nil
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
}
if
lastErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"poll cameo status failed: %w"
,
lastErr
)
}
return
nil
,
errors
.
New
(
"cameo processing timeout"
)
}
func
processSoraCharacterUsername
(
usernameHint
string
)
string
{
usernameHint
=
strings
.
TrimSpace
(
usernameHint
)
if
usernameHint
==
""
{
usernameHint
=
"character"
}
if
strings
.
Contains
(
usernameHint
,
"."
)
{
parts
:=
strings
.
Split
(
usernameHint
,
"."
)
usernameHint
=
strings
.
TrimSpace
(
parts
[
len
(
parts
)
-
1
])
}
if
usernameHint
==
""
{
usernameHint
=
"character"
}
return
fmt
.
Sprintf
(
"%s%d"
,
usernameHint
,
rand
.
Intn
(
900
)
+
100
)
}
func
(
s
*
SoraGatewayService
)
resolveWatermarkFreeURL
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
,
opts
soraWatermarkOptions
)
(
string
,
string
,
error
)
{
generationID
=
strings
.
TrimSpace
(
generationID
)
if
generationID
==
""
{
return
""
,
""
,
errors
.
New
(
"generation id is required for watermark-free mode"
)
}
postID
,
err
:=
s
.
soraClient
.
PostVideoForWatermarkFree
(
ctx
,
account
,
generationID
)
if
err
!=
nil
{
return
""
,
""
,
err
}
postID
=
strings
.
TrimSpace
(
postID
)
if
postID
==
""
{
return
""
,
""
,
errors
.
New
(
"watermark-free publish returned empty post id"
)
}
switch
opts
.
ParseMethod
{
case
"custom"
:
urlVal
,
parseErr
:=
s
.
soraClient
.
GetWatermarkFreeURLCustom
(
ctx
,
account
,
opts
.
ParseURL
,
opts
.
ParseToken
,
postID
)
if
parseErr
!=
nil
{
return
""
,
postID
,
parseErr
}
return
strings
.
TrimSpace
(
urlVal
),
postID
,
nil
case
""
,
"third_party"
:
return
fmt
.
Sprintf
(
"https://oscdn2.dyysy.com/MP4/%s.mp4"
,
postID
),
postID
,
nil
default
:
return
""
,
postID
,
fmt
.
Errorf
(
"unsupported watermark parse method: %s"
,
opts
.
ParseMethod
)
}
}
func
(
s
*
SoraGatewayService
)
shouldFailoverUpstreamError
(
statusCode
int
)
bool
{
switch
statusCode
{
case
401
,
402
,
403
,
404
,
429
,
529
:
return
true
default
:
return
statusCode
>=
500
}
}
func
buildSoraNonStreamResponse
(
content
,
model
string
)
map
[
string
]
any
{
return
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"message"
:
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
content
,
},
"finish_reason"
:
"stop"
,
},
},
}
}
func
soraImageSizeFromModel
(
model
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
size
,
ok
:=
soraImageSizeMap
[
modelLower
];
ok
{
return
size
}
if
strings
.
Contains
(
modelLower
,
"landscape"
)
||
strings
.
Contains
(
modelLower
,
"portrait"
)
{
return
"540"
}
return
"360"
}
func
soraProErrorMessage
(
model
,
upstreamMsg
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
strings
.
Contains
(
modelLower
,
"sora2pro-hd"
)
{
return
"当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
}
if
strings
.
Contains
(
modelLower
,
"sora2pro"
)
{
return
"当前账号无法使用 Sora Pro 模型,请更换模型或账号"
}
return
""
}
func
firstMediaURL
(
urls
[]
string
)
string
{
if
len
(
urls
)
==
0
{
return
""
}
return
urls
[
0
]
}
func
(
s
*
SoraGatewayService
)
buildSoraMediaURL
(
path
string
,
rawQuery
string
)
string
{
if
path
==
""
{
return
path
}
prefix
:=
"/sora/media"
values
:=
url
.
Values
{}
if
rawQuery
!=
""
{
if
parsed
,
err
:=
url
.
ParseQuery
(
rawQuery
);
err
==
nil
{
values
=
parsed
}
}
signKey
:=
""
ttlSeconds
:=
0
if
s
!=
nil
&&
s
.
cfg
!=
nil
{
signKey
=
strings
.
TrimSpace
(
s
.
cfg
.
Gateway
.
SoraMediaSigningKey
)
ttlSeconds
=
s
.
cfg
.
Gateway
.
SoraMediaSignedURLTTLSeconds
}
values
.
Del
(
"sig"
)
values
.
Del
(
"expires"
)
signingQuery
:=
values
.
Encode
()
if
signKey
!=
""
&&
ttlSeconds
>
0
{
expires
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
ttlSeconds
)
*
time
.
Second
)
.
Unix
()
signature
:=
SignSoraMediaURL
(
path
,
signingQuery
,
expires
,
signKey
)
if
signature
!=
""
{
values
.
Set
(
"expires"
,
strconv
.
FormatInt
(
expires
,
10
))
values
.
Set
(
"sig"
,
signature
)
prefix
=
"/sora/media-signed"
}
}
encoded
:=
values
.
Encode
()
if
encoded
==
""
{
return
prefix
+
path
}
return
prefix
+
path
+
"?"
+
encoded
}
func
(
s
*
SoraGatewayService
)
prepareSoraStream
(
c
*
gin
.
Context
,
requestID
string
)
{
if
c
==
nil
{
return
}
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
strings
.
TrimSpace
(
requestID
)
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
}
func
(
s
*
SoraGatewayService
)
writeSoraStream
(
c
*
gin
.
Context
,
model
,
content
string
,
startTime
time
.
Time
)
(
*
int
,
error
)
{
if
c
==
nil
{
return
nil
,
nil
}
writer
:=
c
.
Writer
flusher
,
_
:=
writer
.
(
http
.
Flusher
)
chunk
:=
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{
"content"
:
content
,
},
},
},
}
encoded
,
_
:=
jsonMarshalRaw
(
chunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
encoded
);
err
!=
nil
{
return
nil
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
finalChunk
:=
map
[
string
]
any
{
"id"
:
chunk
[
"id"
],
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{},
"finish_reason"
:
"stop"
,
},
},
}
finalEncoded
,
_
:=
jsonMarshalRaw
(
finalChunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
finalEncoded
);
err
!=
nil
{
return
&
ms
,
err
}
if
_
,
err
:=
fmt
.
Fprint
(
writer
,
"data: [DONE]
\n\n
"
);
err
!=
nil
{
return
&
ms
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
&
ms
,
nil
}
func
(
s
*
SoraGatewayService
)
writeSoraError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
stream
bool
)
{
if
c
==
nil
{
return
}
if
stream
{
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
)
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
"data: [DONE]
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
}
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
,
},
})
}
func
(
s
*
SoraGatewayService
)
handleSoraRequestError
(
ctx
context
.
Context
,
account
*
Account
,
err
error
,
model
string
,
c
*
gin
.
Context
,
stream
bool
)
error
{
if
err
==
nil
{
return
nil
}
var
upstreamErr
*
SoraUpstreamError
if
errors
.
As
(
err
,
&
upstreamErr
)
{
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
logger
.
LegacyPrintf
(
"service.sora"
,
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s"
,
accountID
,
model
,
upstreamErr
.
StatusCode
,
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"x-request-id"
)),
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"cf-ray"
)),
strings
.
TrimSpace
(
upstreamErr
.
Message
),
truncateForLog
(
upstreamErr
.
Body
,
1024
),
)
if
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
upstreamErr
.
StatusCode
,
upstreamErr
.
Headers
,
upstreamErr
.
Body
)
}
if
s
.
shouldFailoverUpstreamError
(
upstreamErr
.
StatusCode
)
{
var
responseHeaders
http
.
Header
if
upstreamErr
.
Headers
!=
nil
{
responseHeaders
=
upstreamErr
.
Headers
.
Clone
()
}
return
&
UpstreamFailoverError
{
StatusCode
:
upstreamErr
.
StatusCode
,
ResponseBody
:
upstreamErr
.
Body
,
ResponseHeaders
:
responseHeaders
,
}
}
msg
:=
upstreamErr
.
Message
if
override
:=
soraProErrorMessage
(
model
,
msg
);
override
!=
""
{
msg
=
override
}
s
.
writeSoraError
(
c
,
upstreamErr
.
StatusCode
,
"upstream_error"
,
msg
,
stream
)
return
err
}
if
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
s
.
writeSoraError
(
c
,
http
.
StatusGatewayTimeout
,
"timeout_error"
,
"Sora generation timeout"
,
stream
)
return
err
}
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"api_error"
,
err
.
Error
(),
stream
)
return
err
}
func
(
s
*
SoraGatewayService
)
pollImageTask
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
([]
string
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetImageTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"succeeded"
,
"completed"
:
return
status
.
URLs
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora image generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora image generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollVideoTaskDetailed
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
(
*
SoraVideoTaskStatus
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetVideoTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"completed"
,
"succeeded"
:
return
status
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora video generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora video generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollInterval
()
time
.
Duration
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
2
*
time
.
Second
}
interval
:=
s
.
cfg
.
Sora
.
Client
.
PollIntervalSeconds
if
interval
<=
0
{
interval
=
2
}
return
time
.
Duration
(
interval
)
*
time
.
Second
}
func
(
s
*
SoraGatewayService
)
pollMaxAttempts
()
int
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
600
}
maxAttempts
:=
s
.
cfg
.
Sora
.
Client
.
MaxPollAttempts
if
maxAttempts
<=
0
{
maxAttempts
=
600
}
return
maxAttempts
}
func
(
s
*
SoraGatewayService
)
maybeSendPing
(
c
*
gin
.
Context
,
lastPing
*
time
.
Time
)
{
if
c
==
nil
{
return
}
interval
:=
10
*
time
.
Second
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Concurrency
.
PingInterval
>
0
{
interval
=
time
.
Duration
(
s
.
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
}
if
time
.
Since
(
*
lastPing
)
<
interval
{
return
}
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
":
\n\n
"
);
err
==
nil
{
if
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
);
ok
{
flusher
.
Flush
()
}
*
lastPing
=
time
.
Now
()
}
}
func
(
s
*
SoraGatewayService
)
normalizeSoraMediaURLs
(
urls
[]
string
)
[]
string
{
if
len
(
urls
)
==
0
{
return
urls
}
output
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
continue
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
output
=
append
(
output
,
raw
)
continue
}
pathVal
:=
raw
if
!
strings
.
HasPrefix
(
pathVal
,
"/"
)
{
pathVal
=
"/"
+
pathVal
}
output
=
append
(
output
,
s
.
buildSoraMediaURL
(
pathVal
,
""
))
}
return
output
}
// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
func
jsonMarshalRaw
(
v
any
)
([]
byte
,
error
)
{
var
buf
bytes
.
Buffer
enc
:=
json
.
NewEncoder
(
&
buf
)
enc
.
SetEscapeHTML
(
false
)
if
err
:=
enc
.
Encode
(
v
);
err
!=
nil
{
return
nil
,
err
}
// Encode 会追加换行符,去掉它
b
:=
buf
.
Bytes
()
if
len
(
b
)
>
0
&&
b
[
len
(
b
)
-
1
]
==
'\n'
{
b
=
b
[
:
len
(
b
)
-
1
]
}
return
b
,
nil
}
func
buildSoraContent
(
mediaType
string
,
urls
[]
string
)
string
{
switch
mediaType
{
case
"image"
:
parts
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
u
:=
range
urls
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
""
,
u
))
}
return
strings
.
Join
(
parts
,
"
\n
"
)
case
"video"
:
if
len
(
urls
)
==
0
{
return
""
}
return
fmt
.
Sprintf
(
"```html
\n
<video src='%s' controls></video>
\n
```"
,
urls
[
0
])
default
:
return
""
}
}
func
extractSoraInput
(
body
map
[
string
]
any
)
(
prompt
,
imageInput
,
videoInput
,
remixTargetID
string
)
{
if
body
==
nil
{
return
""
,
""
,
""
,
""
}
if
v
,
ok
:=
body
[
"remix_target_id"
]
.
(
string
);
ok
{
remixTargetID
=
strings
.
TrimSpace
(
v
)
}
if
v
,
ok
:=
body
[
"image"
]
.
(
string
);
ok
{
imageInput
=
v
}
if
v
,
ok
:=
body
[
"video"
]
.
(
string
);
ok
{
videoInput
=
v
}
if
v
,
ok
:=
body
[
"prompt"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
prompt
=
v
}
if
messages
,
ok
:=
body
[
"messages"
]
.
([]
any
);
ok
{
builder
:=
strings
.
Builder
{}
for
_
,
raw
:=
range
messages
{
msg
,
ok
:=
raw
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
role
,
_
:=
msg
[
"role"
]
.
(
string
)
if
role
!=
""
&&
role
!=
"user"
{
continue
}
content
:=
msg
[
"content"
]
text
,
img
,
vid
:=
parseSoraMessageContent
(
content
)
if
text
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
text
)
}
if
imageInput
==
""
&&
img
!=
""
{
imageInput
=
img
}
if
videoInput
==
""
&&
vid
!=
""
{
videoInput
=
vid
}
}
if
prompt
==
""
{
prompt
=
builder
.
String
()
}
}
if
remixTargetID
==
""
{
remixTargetID
=
extractRemixTargetIDFromPrompt
(
prompt
)
}
prompt
=
cleanRemixLinkFromPrompt
(
prompt
)
return
prompt
,
imageInput
,
videoInput
,
remixTargetID
}
func
parseSoraMessageContent
(
content
any
)
(
text
,
imageInput
,
videoInput
string
)
{
switch
val
:=
content
.
(
type
)
{
case
string
:
return
val
,
""
,
""
case
[]
any
:
builder
:=
strings
.
Builder
{}
for
_
,
item
:=
range
val
{
itemMap
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
t
,
_
:=
itemMap
[
"type"
]
.
(
string
)
switch
t
{
case
"text"
:
if
txt
,
ok
:=
itemMap
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
txt
)
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
txt
)
}
case
"image_url"
:
if
imageInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"image_url"
]
.
(
map
[
string
]
any
);
ok
{
imageInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"image_url"
]
.
(
string
);
ok
{
imageInput
=
urlStr
}
}
case
"video_url"
:
if
videoInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"video_url"
]
.
(
map
[
string
]
any
);
ok
{
videoInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"video_url"
]
.
(
string
);
ok
{
videoInput
=
urlStr
}
}
}
}
return
builder
.
String
(),
imageInput
,
videoInput
default
:
return
""
,
""
,
""
}
}
func
isSoraStoryboardPrompt
(
prompt
string
)
bool
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
false
}
return
len
(
soraStoryboardPattern
.
FindAllString
(
prompt
,
-
1
))
>=
1
}
func
formatSoraStoryboardPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
matches
:=
soraStoryboardShotPattern
.
FindAllStringSubmatch
(
prompt
,
-
1
)
if
len
(
matches
)
==
0
{
return
prompt
}
firstBracketPos
:=
strings
.
Index
(
prompt
,
"["
)
instructions
:=
""
if
firstBracketPos
>
0
{
instructions
=
strings
.
TrimSpace
(
prompt
[
:
firstBracketPos
])
}
shots
:=
make
([]
string
,
0
,
len
(
matches
))
for
i
,
match
:=
range
matches
{
if
len
(
match
)
<
3
{
continue
}
duration
:=
strings
.
TrimSpace
(
match
[
1
])
scene
:=
strings
.
TrimSpace
(
match
[
2
])
if
scene
==
""
{
continue
}
shots
=
append
(
shots
,
fmt
.
Sprintf
(
"Shot %d:
\n
duration: %ssec
\n
Scene: %s"
,
i
+
1
,
duration
,
scene
))
}
if
len
(
shots
)
==
0
{
return
prompt
}
timeline
:=
strings
.
Join
(
shots
,
"
\n\n
"
)
if
instructions
==
""
{
return
timeline
}
return
fmt
.
Sprintf
(
"current timeline:
\n
%s
\n\n
instructions:
\n
%s"
,
timeline
,
instructions
)
}
func
extractRemixTargetIDFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
return
strings
.
TrimSpace
(
soraRemixTargetPattern
.
FindString
(
prompt
))
}
func
cleanRemixLinkFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
prompt
}
cleaned
:=
soraRemixTargetInURLPattern
.
ReplaceAllString
(
prompt
,
""
)
cleaned
=
soraRemixTargetPattern
.
ReplaceAllString
(
cleaned
,
""
)
cleaned
=
strings
.
Join
(
strings
.
Fields
(
cleaned
),
" "
)
return
strings
.
TrimSpace
(
cleaned
)
}
func
decodeSoraImageInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
string
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
""
,
errors
.
New
(
"empty image input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
""
,
errors
.
New
(
"invalid data url"
)
}
meta
:=
parts
[
0
]
payload
:=
parts
[
1
]
decoded
,
err
:=
decodeBase64WithLimit
(
payload
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
""
if
strings
.
HasPrefix
(
meta
,
"data:"
)
{
metaParts
:=
strings
.
SplitN
(
meta
[
5
:
],
";"
,
2
)
if
len
(
metaParts
)
>
0
{
if
exts
,
err
:=
mime
.
ExtensionsByType
(
metaParts
[
0
]);
err
==
nil
&&
len
(
exts
)
>
0
{
ext
=
exts
[
0
]
}
}
}
filename
:=
"image"
+
ext
return
decoded
,
filename
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraImageInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
errors
.
New
(
"invalid base64 image"
)
}
return
decoded
,
"image.png"
,
nil
}
func
decodeSoraVideoInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
errors
.
New
(
"empty video input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
errors
.
New
(
"invalid video data url"
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
parts
[
1
],
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraVideoInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
func
downloadSoraImageInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
string
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraImageInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraImageInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
""
,
fmt
.
Errorf
(
"download image failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraImageInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
fileExtFromURL
(
parsed
.
String
())
if
ext
==
""
{
ext
=
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
))
}
filename
:=
"image"
+
ext
return
data
,
filename
,
nil
}
func
downloadSoraVideoInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraVideoInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraVideoInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"download video failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraVideoInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
data
)
==
0
{
return
nil
,
errors
.
New
(
"empty video content"
)
}
return
data
,
nil
}
func
decodeBase64WithLimit
(
encoded
string
,
maxBytes
int64
)
([]
byte
,
error
)
{
if
maxBytes
<=
0
{
return
nil
,
errors
.
New
(
"invalid max bytes limit"
)
}
decoder
:=
base64
.
NewDecoder
(
base64
.
StdEncoding
,
strings
.
NewReader
(
encoded
))
limited
:=
io
.
LimitReader
(
decoder
,
maxBytes
+
1
)
data
,
err
:=
io
.
ReadAll
(
limited
)
if
err
!=
nil
{
return
nil
,
err
}
if
int64
(
len
(
data
))
>
maxBytes
{
return
nil
,
fmt
.
Errorf
(
"input exceeds %d bytes limit"
,
maxBytes
)
}
return
data
,
nil
}
func
validateSoraRemoteURL
(
raw
string
)
(
*
url
.
URL
,
error
)
{
if
strings
.
TrimSpace
(
raw
)
==
""
{
return
nil
,
errors
.
New
(
"empty remote url"
)
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid remote url: %w"
,
err
)
}
if
err
:=
validateSoraRemoteURLValue
(
parsed
);
err
!=
nil
{
return
nil
,
err
}
return
parsed
,
nil
}
func
validateSoraRemoteURLValue
(
parsed
*
url
.
URL
)
error
{
if
parsed
==
nil
{
return
errors
.
New
(
"invalid remote url"
)
}
scheme
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Scheme
))
if
scheme
!=
"http"
&&
scheme
!=
"https"
{
return
errors
.
New
(
"only http/https remote url is allowed"
)
}
if
parsed
.
User
!=
nil
{
return
errors
.
New
(
"remote url cannot contain userinfo"
)
}
host
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Hostname
()))
if
host
==
""
{
return
errors
.
New
(
"remote url missing host"
)
}
if
_
,
blocked
:=
soraBlockedHostnames
[
host
];
blocked
{
return
errors
.
New
(
"remote url is not allowed"
)
}
if
ip
:=
net
.
ParseIP
(
host
);
ip
!=
nil
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
return
nil
}
ips
,
err
:=
net
.
LookupIP
(
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"resolve remote url failed: %w"
,
err
)
}
for
_
,
ip
:=
range
ips
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
}
return
nil
}
func
isSoraBlockedIP
(
ip
net
.
IP
)
bool
{
if
ip
==
nil
{
return
true
}
for
_
,
cidr
:=
range
soraBlockedCIDRs
{
if
cidr
.
Contains
(
ip
)
{
return
true
}
}
return
false
}
func
mustParseCIDRs
(
values
[]
string
)
[]
*
net
.
IPNet
{
out
:=
make
([]
*
net
.
IPNet
,
0
,
len
(
values
))
for
_
,
val
:=
range
values
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
val
)
if
err
!=
nil
{
continue
}
out
=
append
(
out
,
cidr
)
}
return
out
}
backend/internal/service/sora_gateway_service_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
service
import
(
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var
_
SoraClient
=
(
*
stubSoraClientForPoll
)(
nil
)
type
stubSoraClientForPoll
struct
{
imageStatus
*
SoraImageTaskStatus
videoStatus
*
SoraVideoTaskStatus
imageCalls
int
videoCalls
int
enhanced
string
enhanceErr
error
storyboard
bool
videoReq
SoraVideoRequest
parseErr
error
postCalls
int
deleteCalls
int
}
func
(
s
*
stubSoraClientForPoll
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClientForPoll
)
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
{
s
.
videoReq
=
req
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
{
s
.
storyboard
=
true
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
return
&
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClientForPoll
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
{
s
.
postCalls
++
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
{
s
.
deleteCalls
++
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
if
s
.
parseErr
!=
nil
{
return
""
,
s
.
parseErr
}
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
if
s
.
enhanced
!=
""
{
return
s
.
enhanced
,
s
.
enhanceErr
}
return
"enhanced prompt"
,
s
.
enhanceErr
}
func
(
s
*
stubSoraClientForPoll
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
{
s
.
imageCalls
++
return
s
.
imageStatus
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
{
s
.
videoCalls
++
return
s
.
videoStatus
,
nil
}
func
TestSoraGatewayService_PollImageTaskCompleted
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
imageStatus
:
&
SoraImageTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/a.png"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
urls
,
err
:=
service
.
pollImageTask
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"https://example.com/a.png"
},
urls
)
require
.
Equal
(
t
,
1
,
client
.
imageCalls
)
}
func
TestSoraGatewayService_ForwardPromptEnhance
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
enhanced
:
"cinematic prompt"
,
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"prompt-enhance-short-10s"
:
"prompt-enhance-short-15s"
,
},
},
}
body
:=
[]
byte
(
`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
"prompt-enhance-short-10s"
,
result
.
Model
)
require
.
Equal
(
t
,
"prompt-enhance-short-15s"
,
result
.
UpstreamModel
)
}
func
TestSoraGatewayService_ForwardStoryboardPrompt
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
client
.
storyboard
)
}
func
TestSoraGatewayService_ForwardVideoCount
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
3
,
client
.
videoReq
.
VideoCount
)
}
func
TestSoraGatewayService_ForwardCharacterOnly
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
0
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkFallback
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
parseErr
:
errors
.
New
(
"parse failed"
),
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/original.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
0
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/no-watermark.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
1
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_PollVideoTaskFailed
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"reject"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
status
,
err
:=
service
.
pollVideoTaskDetailed
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
status
)
require
.
Contains
(
t
,
err
.
Error
(),
"reject"
)
require
.
Equal
(
t
,
1
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_BuildSoraMediaURLSigned
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraMediaSigningKey
:
"test-key"
,
SoraMediaSignedURLTTLSeconds
:
600
,
},
}
service
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
url
:=
service
.
buildSoraMediaURL
(
"/image/2025/01/01/a.png"
,
""
)
require
.
Contains
(
t
,
url
,
"/sora/media-signed"
)
require
.
Contains
(
t
,
url
,
"expires="
)
require
.
Contains
(
t
,
url
,
"sig="
)
}
func
TestNormalizeSoraMediaURLs_Empty
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
result
:=
svc
.
normalizeSoraMediaURLs
(
nil
)
require
.
Empty
(
t
,
result
)
result
=
svc
.
normalizeSoraMediaURLs
([]
string
{})
require
.
Empty
(
t
,
result
)
}
func
TestNormalizeSoraMediaURLs_HTTPUrls
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
"http://example.com/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Equal
(
t
,
urls
,
result
)
}
func
TestNormalizeSoraMediaURLs_LocalPaths
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
urls
:=
[]
string
{
"/image/2025/01/a.png"
,
"video/2025/01/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
require
.
Contains
(
t
,
result
[
0
],
"/sora/media"
)
require
.
Contains
(
t
,
result
[
1
],
"/sora/media"
)
}
func
TestNormalizeSoraMediaURLs_SkipsBlank
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
""
,
" "
,
"https://example.com/b.png"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
}
func
TestBuildSoraContent_Image
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"image"
,
[]
string
{
"https://a.com/1.png"
,
"https://a.com/2.png"
})
require
.
Contains
(
t
,
content
,
""
)
require
.
Contains
(
t
,
content
,
""
)
}
func
TestBuildSoraContent_Video
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
[]
string
{
"https://a.com/v.mp4"
})
require
.
Contains
(
t
,
content
,
"<video src='https://a.com/v.mp4'"
)
}
func
TestBuildSoraContent_VideoEmpty
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestBuildSoraContent_Prompt
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"prompt"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestSoraImageSizeFromModel
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"gpt-image"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-landscape"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-portrait"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"something-landscape"
))
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"unknown-model"
))
}
func
TestFirstMediaURL
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
firstMediaURL
(
nil
))
require
.
Equal
(
t
,
""
,
firstMediaURL
([]
string
{}))
require
.
Equal
(
t
,
"a"
,
firstMediaURL
([]
string
{
"a"
,
"b"
}))
}
func
TestSoraProErrorMessage
(
t
*
testing
.
T
)
{
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro-hd"
,
""
),
"Pro-HD"
)
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro"
,
""
),
"Pro"
)
require
.
Empty
(
t
,
soraProErrorMessage
(
"sora-basic"
,
""
))
}
func
TestSoraGatewayService_WriteSoraError_StreamEscapesJSON
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"invalid
\"
prompt
\"\n
line2"
,
true
)
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: error
\n
"
)
require
.
Contains
(
t
,
body
,
"data: [DONE]
\n\n
"
)
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
require
.
GreaterOrEqual
(
t
,
len
(
lines
),
2
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
))
data
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
data
),
&
parsed
))
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errObj
[
"message"
])
}
func
TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
sourceHeaders
:=
http
.
Header
{}
sourceHeaders
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
err
:=
svc
.
handleSoraRequestError
(
context
.
Background
(),
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
},
&
SoraUpstreamError
{
StatusCode
:
http
.
StatusForbidden
,
Message
:
"forbidden"
,
Headers
:
sourceHeaders
,
Body
:
[]
byte
(
`<!DOCTYPE html><title>Just a moment...</title>`
),
},
"sora2-landscape-10s"
,
nil
,
false
,
)
var
failoverErr
*
UpstreamFailoverError
require
.
ErrorAs
(
t
,
err
,
&
failoverErr
)
require
.
NotNil
(
t
,
failoverErr
.
ResponseHeaders
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
sourceHeaders
.
Set
(
"cf-ray"
,
"mutated-after-return"
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
}
func
TestShouldFailoverUpstreamError
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
401
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
404
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
429
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
500
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
502
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
200
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
400
))
}
func
TestWithSoraTimeout_NilService
(
t
*
testing
.
T
)
{
var
svc
*
SoraGatewayService
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_ZeroTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_PositiveTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraRequestTimeoutSeconds
:
30
,
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
NotNil
(
t
,
cancel
)
cancel
()
}
func
TestPollInterval
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
5
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
5
*
time
.
Second
,
svc
.
pollInterval
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollInterval
()
>
0
)
}
func
TestPollMaxAttempts
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
MaxPollAttempts
:
100
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
100
,
svc
.
pollMaxAttempts
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollMaxAttempts
()
>
0
)
}
func
TestDecodeSoraImageInput_BlockPrivateURL
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
"http://127.0.0.1/internal.png"
)
require
.
Error
(
t
,
err
)
}
func
TestDecodeSoraImageInput_DataURL
(
t
*
testing
.
T
)
{
encoded
:=
"data:image/png;base64,aGVsbG8="
data
,
filename
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
encoded
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
data
)
require
.
Contains
(
t
,
filename
,
".png"
)
}
func
TestDecodeBase64WithLimit_ExceedLimit
(
t
*
testing
.
T
)
{
data
,
err
:=
decodeBase64WithLimit
(
"aGVsbG8="
,
3
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
data
)
}
func
TestParseSoraWatermarkOptions_NumericBool
(
t
*
testing
.
T
)
{
body
:=
map
[
string
]
any
{
"watermark_free"
:
float64
(
1
),
"watermark_fallback_on_failure"
:
float64
(
0
),
}
opts
:=
parseSoraWatermarkOptions
(
body
)
require
.
True
(
t
,
opts
.
Enabled
)
require
.
False
(
t
,
opts
.
FallbackOnFailure
)
}
func
TestParseSoraVideoCount
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
nil
))
require
.
Equal
(
t
,
2
,
parseSoraVideoCount
(
map
[
string
]
any
{
"video_count"
:
float64
(
2
)}))
require
.
Equal
(
t
,
3
,
parseSoraVideoCount
(
map
[
string
]
any
{
"videos"
:
"5"
}))
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
map
[
string
]
any
{
"n_variants"
:
0
}))
}
backend/internal/service/sora_gateway_streaming_legacy.go
deleted
100644 → 0
View file @
dbb248df
//nolint:unused
package
service
import
(
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var
soraSSEDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
var
soraImageMarkdownRe
=
regexp
.
MustCompile
(
`!\[[^\]]*\]\(([^)]+)\)`
)
var
soraVideoHTMLRe
=
regexp
.
MustCompile
(
`(?i)<video[^>]+src=['"]([^'"]+)['"]`
)
const
soraRewriteBufferLimit
=
2048
type
soraStreamingResult
struct
{
mediaType
string
mediaURLs
[]
string
imageCount
int
imageSize
string
firstTokenMs
*
int
}
func
(
s
*
SoraGatewayService
)
setUpstreamRequestError
(
c
*
gin
.
Context
,
account
*
Account
,
err
error
)
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
}
}
func
(
s
*
SoraGatewayService
)
handleFailoverSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
if
s
.
rateLimitService
==
nil
||
account
==
nil
||
resp
==
nil
{
return
}
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
func
(
s
*
SoraGatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
reqModel
string
)
(
*
ForwardResult
,
error
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
msg
:=
soraProErrorMessage
(
reqModel
,
upstreamMsg
);
msg
!=
""
{
upstreamMsg
=
msg
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
c
!=
nil
{
responsePayload
:=
s
.
buildErrorPayload
(
respBody
,
upstreamMsg
)
c
.
JSON
(
resp
.
StatusCode
,
responsePayload
)
}
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
func
(
s
*
SoraGatewayService
)
buildErrorPayload
(
respBody
[]
byte
,
overrideMessage
string
)
map
[
string
]
any
{
if
len
(
respBody
)
>
0
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
respBody
,
&
payload
);
err
==
nil
{
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
overrideMessage
!=
""
{
errObj
[
"message"
]
=
overrideMessage
}
payload
[
"error"
]
=
errObj
return
payload
}
}
}
return
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
overrideMessage
,
},
}
}
func
(
s
*
SoraGatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
string
,
clientStream
bool
)
(
*
soraStreamingResult
,
error
)
{
if
resp
==
nil
{
return
nil
,
errors
.
New
(
"empty response"
)
}
if
clientStream
{
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
v
:=
resp
.
Header
.
Get
(
"x-request-id"
);
v
!=
""
{
c
.
Header
(
"x-request-id"
,
v
)
}
}
w
:=
c
.
Writer
flusher
,
_
:=
w
.
(
http
.
Flusher
)
contentBuilder
:=
strings
.
Builder
{}
var
firstTokenMs
*
int
var
upstreamError
error
rewriteBuffer
:=
""
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
sendLine
:=
func
(
line
string
)
error
{
if
!
clientStream
{
return
nil
}
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
nil
}
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
if
soraSSEDataRe
.
MatchString
(
line
)
{
data
:=
soraSSEDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
"[DONE]"
{
if
rewriteBuffer
!=
""
{
flushLine
,
flushContent
,
err
:=
s
.
flushSoraRewriteBuffer
(
rewriteBuffer
,
originalModel
)
if
err
!=
nil
{
return
nil
,
err
}
if
flushLine
!=
""
{
if
flushContent
!=
""
{
if
_
,
err
:=
contentBuilder
.
WriteString
(
flushContent
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
flushLine
);
err
!=
nil
{
return
nil
,
err
}
}
rewriteBuffer
=
""
}
if
err
:=
sendLine
(
"data: [DONE]"
);
err
!=
nil
{
return
nil
,
err
}
break
}
updatedLine
,
contentDelta
,
errEvent
:=
s
.
processSoraSSEData
(
data
,
originalModel
,
&
rewriteBuffer
)
if
errEvent
!=
nil
&&
upstreamError
==
nil
{
upstreamError
=
errEvent
}
if
contentDelta
!=
""
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
err
:=
contentBuilder
.
WriteString
(
contentDelta
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
updatedLine
);
err
!=
nil
{
return
nil
,
err
}
continue
}
if
err
:=
sendLine
(
line
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
if
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
response_too_large
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
if
ctx
.
Err
()
==
context
.
DeadlineExceeded
&&
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleStreamTimeout
(
ctx
,
account
,
originalModel
)
}
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
stream_read_error
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
content
:=
contentBuilder
.
String
()
mediaType
,
mediaURLs
:=
s
.
extractSoraMedia
(
content
)
if
mediaType
==
""
&&
isSoraPromptEnhanceModel
(
originalModel
)
{
mediaType
=
"prompt"
}
imageSize
:=
""
imageCount
:=
0
if
mediaType
==
"image"
{
imageSize
=
soraImageSizeFromModel
(
originalModel
)
imageCount
=
len
(
mediaURLs
)
}
if
upstreamError
!=
nil
&&
!
clientStream
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
upstreamError
.
Error
(),
},
})
}
return
nil
,
upstreamError
}
if
!
clientStream
{
response
:=
buildSoraNonStreamResponse
(
content
,
originalModel
)
if
len
(
mediaURLs
)
>
0
{
response
[
"media_url"
]
=
mediaURLs
[
0
]
if
len
(
mediaURLs
)
>
1
{
response
[
"media_urls"
]
=
mediaURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
soraStreamingResult
{
mediaType
:
mediaType
,
mediaURLs
:
mediaURLs
,
imageCount
:
imageCount
,
imageSize
:
imageSize
,
firstTokenMs
:
firstTokenMs
,
},
nil
}
func
(
s
*
SoraGatewayService
)
processSoraSSEData
(
data
string
,
originalModel
string
,
rewriteBuffer
*
string
)
(
string
,
string
,
error
)
{
if
strings
.
TrimSpace
(
data
)
==
""
{
return
"data: "
,
""
,
nil
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
payload
);
err
!=
nil
{
return
"data: "
+
data
,
""
,
nil
}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
return
"data: "
+
data
,
""
,
errors
.
New
(
msg
)
}
}
if
model
,
ok
:=
payload
[
"model"
]
.
(
string
);
ok
&&
model
!=
""
&&
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
contentDelta
,
updated
:=
extractSoraContent
(
payload
)
if
updated
{
var
rewritten
string
if
rewriteBuffer
!=
nil
{
rewritten
=
s
.
rewriteSoraContentWithBuffer
(
contentDelta
,
rewriteBuffer
)
}
else
{
rewritten
=
s
.
rewriteSoraContent
(
contentDelta
)
}
if
rewritten
!=
contentDelta
{
applySoraContent
(
payload
,
rewritten
)
contentDelta
=
rewritten
}
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
"data: "
+
data
,
contentDelta
,
nil
}
return
"data: "
+
string
(
updatedData
),
contentDelta
,
nil
}
func
extractSoraContent
(
payload
map
[
string
]
any
)
(
string
,
bool
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
""
,
false
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
""
,
false
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
delta
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
message
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
return
""
,
false
}
func
applySoraContent
(
payload
map
[
string
]
any
,
content
string
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
delta
[
"content"
]
=
content
choice
[
"delta"
]
=
delta
return
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
message
[
"content"
]
=
content
choice
[
"message"
]
=
message
}
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContentWithBuffer
(
contentDelta
string
,
buffer
*
string
)
string
{
if
buffer
==
nil
{
return
s
.
rewriteSoraContent
(
contentDelta
)
}
if
contentDelta
==
""
&&
*
buffer
==
""
{
return
""
}
combined
:=
*
buffer
+
contentDelta
rewritten
:=
s
.
rewriteSoraContent
(
combined
)
bufferStart
:=
s
.
findSoraRewriteBufferStart
(
rewritten
)
if
bufferStart
<
0
{
*
buffer
=
""
return
rewritten
}
if
len
(
rewritten
)
-
bufferStart
>
soraRewriteBufferLimit
{
bufferStart
=
len
(
rewritten
)
-
soraRewriteBufferLimit
}
output
:=
rewritten
[
:
bufferStart
]
*
buffer
=
rewritten
[
bufferStart
:
]
return
output
}
func
(
s
*
SoraGatewayService
)
findSoraRewriteBufferStart
(
content
string
)
int
{
minIndex
:=
-
1
start
:=
0
for
{
idx
:=
strings
.
Index
(
content
[
start
:
],
"!["
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraImageMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
2
}
lower
:=
strings
.
ToLower
(
content
)
start
=
0
for
{
idx
:=
strings
.
Index
(
lower
[
start
:
],
"<video"
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraVideoMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
len
(
"<video"
)
}
return
minIndex
}
func
hasSoraImageMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraImageMarkdownRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
hasSoraVideoMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraVideoHTMLRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContent
(
content
string
)
string
{
if
content
==
""
{
return
content
}
content
=
soraImageMarkdownRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraImageMarkdownRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
content
=
soraVideoHTMLRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
return
content
}
func
(
s
*
SoraGatewayService
)
flushSoraRewriteBuffer
(
buffer
string
,
originalModel
string
)
(
string
,
string
,
error
)
{
if
buffer
==
""
{
return
""
,
""
,
nil
}
rewritten
:=
s
.
rewriteSoraContent
(
buffer
)
payload
:=
map
[
string
]
any
{
"choices"
:
[]
any
{
map
[
string
]
any
{
"delta"
:
map
[
string
]
any
{
"content"
:
rewritten
,
},
"index"
:
0
,
},
},
}
if
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
"data: "
+
string
(
updatedData
),
rewritten
,
nil
}
func
(
s
*
SoraGatewayService
)
rewriteSoraURL
(
raw
string
)
string
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
raw
}
path
:=
parsed
.
Path
if
!
strings
.
HasPrefix
(
path
,
"/tmp/"
)
&&
!
strings
.
HasPrefix
(
path
,
"/static/"
)
{
return
raw
}
return
s
.
buildSoraMediaURL
(
path
,
parsed
.
RawQuery
)
}
func
(
s
*
SoraGatewayService
)
extractSoraMedia
(
content
string
)
(
string
,
[]
string
)
{
if
content
==
""
{
return
""
,
nil
}
if
match
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
content
);
len
(
match
)
>
1
{
return
"video"
,
[]
string
{
match
[
1
]}
}
imageMatches
:=
soraImageMarkdownRe
.
FindAllStringSubmatch
(
content
,
-
1
)
if
len
(
imageMatches
)
==
0
{
return
""
,
nil
}
urls
:=
make
([]
string
,
0
,
len
(
imageMatches
))
for
_
,
match
:=
range
imageMatches
{
if
len
(
match
)
>
1
{
urls
=
append
(
urls
,
match
[
1
])
}
}
return
"image"
,
urls
}
func
isSoraPromptEnhanceModel
(
model
string
)
bool
{
return
strings
.
HasPrefix
(
strings
.
ToLower
(
strings
.
TrimSpace
(
model
)),
"prompt-enhance"
)
}
backend/internal/service/sora_generation.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type
SoraGeneration
struct
{
ID
int64
`json:"id"`
UserID
int64
`json:"user_id"`
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
Model
string
`json:"model"`
Prompt
string
`json:"prompt"`
MediaType
string
`json:"media_type"`
// video / image
Status
string
`json:"status"`
// pending / generating / completed / failed / cancelled
MediaURL
string
`json:"media_url"`
// 主媒体 URL(预签名或 CDN)
MediaURLs
[]
string
`json:"media_urls"`
// 多图时的 URL 数组
FileSizeBytes
int64
`json:"file_size_bytes"`
StorageType
string
`json:"storage_type"`
// s3 / local / upstream / none
S3ObjectKeys
[]
string
`json:"s3_object_keys"`
// S3 object key 数组
UpstreamTaskID
string
`json:"upstream_task_id"`
ErrorMessage
string
`json:"error_message"`
CreatedAt
time
.
Time
`json:"created_at"`
CompletedAt
*
time
.
Time
`json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const
(
SoraGenStatusPending
=
"pending"
SoraGenStatusGenerating
=
"generating"
SoraGenStatusCompleted
=
"completed"
SoraGenStatusFailed
=
"failed"
SoraGenStatusCancelled
=
"cancelled"
)
// Sora 存储类型常量
const
(
SoraStorageTypeS3
=
"s3"
SoraStorageTypeLocal
=
"local"
SoraStorageTypeUpstream
=
"upstream"
SoraStorageTypeNone
=
"none"
)
// SoraGenerationListParams 查询生成记录的参数。
type
SoraGenerationListParams
struct
{
UserID
int64
Status
string
// 可选筛选
StorageType
string
// 可选筛选
MediaType
string
// 可选筛选
Page
int
PageSize
int
}
// SoraGenerationRepository 生成记录持久化接口。
type
SoraGenerationRepository
interface
{
Create
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
Update
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
CountByUserAndStatus
(
ctx
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
}
backend/internal/service/sora_generation_service.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var
(
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit
=
errors
.
New
(
"sora generation concurrent limit exceeded"
)
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict
=
errors
.
New
(
"sora generation state conflict"
)
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive
=
errors
.
New
(
"sora generation is not active"
)
)
const
soraGenerationActiveLimit
=
3
type
soraGenerationRepoAtomicCreator
interface
{
CreatePendingWithLimit
(
ctx
context
.
Context
,
gen
*
SoraGeneration
,
activeStatuses
[]
string
,
maxActive
int64
)
error
}
type
soraGenerationRepoConditionalUpdater
interface
{
UpdateGeneratingIfPending
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
(
bool
,
error
)
UpdateCompletedIfActive
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateFailedIfActive
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateCancelledIfActive
(
ctx
context
.
Context
,
id
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateStorageIfCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
(
bool
,
error
)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type
SoraGenerationService
struct
{
genRepo
SoraGenerationRepository
s3Storage
*
SoraS3Storage
quotaService
*
SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func
NewSoraGenerationService
(
genRepo
SoraGenerationRepository
,
s3Storage
*
SoraS3Storage
,
quotaService
*
SoraQuotaService
,
)
*
SoraGenerationService
{
return
&
SoraGenerationService
{
genRepo
:
genRepo
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func
(
s
*
SoraGenerationService
)
CreatePending
(
ctx
context
.
Context
,
userID
int64
,
apiKeyID
*
int64
,
model
,
prompt
,
mediaType
string
)
(
*
SoraGeneration
,
error
)
{
gen
:=
&
SoraGeneration
{
UserID
:
userID
,
APIKeyID
:
apiKeyID
,
Model
:
model
,
Prompt
:
prompt
,
MediaType
:
mediaType
,
Status
:
SoraGenStatusPending
,
StorageType
:
SoraStorageTypeNone
,
}
if
atomicCreator
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoAtomicCreator
);
ok
{
if
err
:=
atomicCreator
.
CreatePendingWithLimit
(
ctx
,
gen
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
},
soraGenerationActiveLimit
,
);
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSoraGenerationConcurrencyLimit
)
{
return
nil
,
err
}
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
if
err
:=
s
.
genRepo
.
Create
(
ctx
,
gen
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
// MarkGenerating 标记为生成中。
func
(
s
*
SoraGenerationService
)
MarkGenerating
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateGeneratingIfPending
(
ctx
,
id
,
upstreamTaskID
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusGenerating
gen
.
UpstreamTaskID
=
upstreamTaskID
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCompleted 标记为已完成。
func
(
s
*
SoraGenerationService
)
MarkCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCompletedIfActive
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusCompleted
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkFailed 标记为失败。
func
(
s
*
SoraGenerationService
)
MarkFailed
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateFailedIfActive
(
ctx
,
id
,
errMsg
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusFailed
gen
.
ErrorMessage
=
errMsg
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCancelled 标记为已取消。
func
(
s
*
SoraGenerationService
)
MarkCancelled
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCancelledIfActive
(
ctx
,
id
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationNotActive
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationNotActive
}
gen
.
Status
=
SoraGenStatusCancelled
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func
(
s
*
SoraGenerationService
)
UpdateStorageForCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateStorageIfCompleted
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusCompleted
{
return
ErrSoraGenerationStateConflict
}
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// GetByID 获取记录详情(含权限校验)。
func
(
s
*
SoraGenerationService
)
GetByID
(
ctx
context
.
Context
,
id
,
userID
int64
)
(
*
SoraGeneration
,
error
)
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
gen
.
UserID
!=
userID
{
return
nil
,
fmt
.
Errorf
(
"无权访问此生成记录"
)
}
return
gen
,
nil
}
// List 查询生成记录列表(分页 + 筛选)。
func
(
s
*
SoraGenerationService
)
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
params
.
Page
<=
0
{
params
.
Page
=
1
}
if
params
.
PageSize
<=
0
{
params
.
PageSize
=
20
}
if
params
.
PageSize
>
100
{
params
.
PageSize
=
100
}
return
s
.
genRepo
.
List
(
ctx
,
params
)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func
(
s
*
SoraGenerationService
)
Delete
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
UserID
!=
userID
{
return
fmt
.
Errorf
(
"无权删除此生成记录"
)
}
// 清理 S3 文件
if
gen
.
StorageType
==
SoraStorageTypeS3
&&
len
(
gen
.
S3ObjectKeys
)
>
0
&&
s
.
s3Storage
!=
nil
{
if
err
:=
s
.
s3Storage
.
DeleteObjects
(
ctx
,
gen
.
S3ObjectKeys
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] S3 清理失败 id=%d err=%v"
,
id
,
err
)
}
}
// 释放配额(S3/本地均释放)
if
gen
.
FileSizeBytes
>
0
&&
(
gen
.
StorageType
==
SoraStorageTypeS3
||
gen
.
StorageType
==
SoraStorageTypeLocal
)
&&
s
.
quotaService
!=
nil
{
if
err
:=
s
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
gen
.
FileSizeBytes
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 配额释放失败 id=%d err=%v"
,
id
,
err
)
}
}
return
s
.
genRepo
.
Delete
(
ctx
,
id
)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func
(
s
*
SoraGenerationService
)
CountActiveByUser
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
s
.
genRepo
.
CountByUserAndStatus
(
ctx
,
userID
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func
(
s
*
SoraGenerationService
)
ResolveMediaURLs
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
gen
==
nil
||
gen
.
StorageType
!=
SoraStorageTypeS3
||
s
.
s3Storage
==
nil
{
return
nil
}
if
len
(
gen
.
S3ObjectKeys
)
==
0
{
return
nil
}
urls
:=
make
([]
string
,
len
(
gen
.
S3ObjectKeys
))
var
wg
sync
.
WaitGroup
var
firstErr
error
var
errMu
sync
.
Mutex
for
idx
,
key
:=
range
gen
.
S3ObjectKeys
{
wg
.
Add
(
1
)
go
func
(
i
int
,
objectKey
string
)
{
defer
wg
.
Done
()
url
,
err
:=
s
.
s3Storage
.
GetAccessURL
(
ctx
,
objectKey
)
if
err
!=
nil
{
errMu
.
Lock
()
if
firstErr
==
nil
{
firstErr
=
err
}
errMu
.
Unlock
()
return
}
urls
[
i
]
=
url
}(
idx
,
key
)
}
wg
.
Wait
()
if
firstErr
!=
nil
{
return
firstErr
}
gen
.
MediaURL
=
urls
[
0
]
gen
.
MediaURLs
=
urls
return
nil
}
backend/internal/service/sora_generation_service_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
service
import
(
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var
_
SoraGenerationRepository
=
(
*
stubGenRepo
)(
nil
)
type
stubGenRepo
struct
{
gens
map
[
int64
]
*
SoraGeneration
nextID
int64
createErr
error
getErr
error
updateErr
error
deleteErr
error
listErr
error
countErr
error
countValue
int64
}
func
newStubGenRepo
()
*
stubGenRepo
{
return
&
stubGenRepo
{
gens
:
make
(
map
[
int64
]
*
SoraGeneration
),
nextID
:
1
}
}
func
(
r
*
stubGenRepo
)
Create
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
createErr
!=
nil
{
return
r
.
createErr
}
gen
.
ID
=
r
.
nextID
gen
.
CreatedAt
=
time
.
Now
()
r
.
nextID
++
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
if
gen
,
ok
:=
r
.
gens
[
id
];
ok
{
return
gen
,
nil
}
return
nil
,
fmt
.
Errorf
(
"not found"
)
}
func
(
r
*
stubGenRepo
)
Update
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
Delete
(
_
context
.
Context
,
id
int64
)
error
{
if
r
.
deleteErr
!=
nil
{
return
r
.
deleteErr
}
delete
(
r
.
gens
,
id
)
return
nil
}
func
(
r
*
stubGenRepo
)
List
(
_
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
r
.
listErr
!=
nil
{
return
nil
,
0
,
r
.
listErr
}
var
result
[]
*
SoraGeneration
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
!=
params
.
UserID
{
continue
}
if
params
.
Status
!=
""
&&
gen
.
Status
!=
params
.
Status
{
continue
}
if
params
.
StorageType
!=
""
&&
gen
.
StorageType
!=
params
.
StorageType
{
continue
}
if
params
.
MediaType
!=
""
&&
gen
.
MediaType
!=
params
.
MediaType
{
continue
}
result
=
append
(
result
,
gen
)
}
return
result
,
int64
(
len
(
result
)),
nil
}
func
(
r
*
stubGenRepo
)
CountByUserAndStatus
(
_
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
{
if
r
.
countErr
!=
nil
{
return
0
,
r
.
countErr
}
if
r
.
countValue
>
0
{
return
r
.
countValue
,
nil
}
var
count
int64
statusSet
:=
make
(
map
[
string
]
struct
{})
for
_
,
s
:=
range
statuses
{
statusSet
[
s
]
=
struct
{}{}
}
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
==
userID
{
if
_
,
ok
:=
statusSet
[
gen
.
Status
];
ok
{
count
++
}
}
}
return
count
,
nil
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var
_
UserRepository
=
(
*
stubUserRepoForQuota
)(
nil
)
type
stubUserRepoForQuota
struct
{
users
map
[
int64
]
*
User
updateErr
error
}
func
newStubUserRepoForQuota
()
*
stubUserRepoForQuota
{
return
&
stubUserRepoForQuota
{
users
:
make
(
map
[
int64
]
*
User
)}
}
func
(
r
*
stubUserRepoForQuota
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
if
u
,
ok
:=
r
.
users
[
id
];
ok
{
return
u
,
nil
}
return
nil
,
fmt
.
Errorf
(
"user not found"
)
}
func
(
r
*
stubUserRepoForQuota
)
Update
(
_
context
.
Context
,
user
*
User
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
users
[
user
.
ID
]
=
user
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetByEmail
(
context
.
Context
,
string
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetFirstAdmin
(
context
.
Context
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
func
newS3StorageWithCDN
(
cdnURL
string
)
*
SoraS3Storage
{
storage
:=
&
SoraS3Storage
{}
storage
.
cfg
=
&
SoraS3Settings
{
Enabled
:
true
,
Bucket
:
"test-bucket"
,
CDNURL
:
cdnURL
,
}
// 需要 non-nil client 使 getClient 命中缓存
storage
.
client
=
s3
.
New
(
s3
.
Options
{})
return
storage
}
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
func
newS3StorageFailingDelete
()
*
SoraS3Storage
{
return
&
SoraS3Storage
{}
// settingService 为 nil → getConfig 返回 error
}
// ==================== CreatePending ====================
func
TestCreatePending_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"一只猫跳舞"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
UserID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
require
.
Equal
(
t
,
"一只猫跳舞"
,
gen
.
Prompt
)
require
.
Equal
(
t
,
"video"
,
gen
.
MediaType
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
require
.
Equal
(
t
,
SoraStorageTypeNone
,
gen
.
StorageType
)
require
.
Nil
(
t
,
gen
.
APIKeyID
)
}
func
TestCreatePending_WithAPIKeyID
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyID
:=
int64
(
42
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
&
apiKeyID
,
"gpt-image"
,
"画一朵花"
,
"image"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
gen
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
gen
.
APIKeyID
)
}
func
TestCreatePending_RepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
createErr
=
fmt
.
Errorf
(
"db write error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"create generation"
)
}
// ==================== MarkGenerating ====================
func
TestMarkGenerating_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
"upstream-task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
"upstream-task-123"
,
repo
.
gens
[
1
]
.
UpstreamTaskID
)
}
func
TestMarkGenerating_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
999
,
""
)
require
.
Error
(
t
,
err
)
}
func
TestMarkGenerating_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
""
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCompleted ====================
func
TestMarkCompleted_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"https://cdn.example.com/video.mp4"
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
SoraStorageTypeS3
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
1048576
,
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
gen
.
Status
)
require
.
Equal
(
t
,
"https://cdn.example.com/video.mp4"
,
gen
.
MediaURL
)
require
.
Equal
(
t
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
gen
.
MediaURLs
)
require
.
Equal
(
t
,
SoraStorageTypeS3
,
gen
.
StorageType
)
require
.
Equal
(
t
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
gen
.
S3ObjectKeys
)
require
.
Equal
(
t
,
int64
(
1048576
),
gen
.
FileSizeBytes
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkCompleted_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
999
,
""
,
nil
,
""
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCompleted_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"url"
,
nil
,
SoraStorageTypeUpstream
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkFailed ====================
func
TestMarkFailed_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"上游返回 500 错误"
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusFailed
,
gen
.
Status
)
require
.
Equal
(
t
,
"上游返回 500 错误"
,
gen
.
ErrorMessage
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkFailed_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
999
,
"error"
)
require
.
Error
(
t
,
err
)
}
func
TestMarkFailed_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"err"
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCancelled ====================
func
TestMarkCancelled_Pending
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
CompletedAt
)
}
func
TestMarkCancelled_Generating
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestMarkCancelled_Completed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrSoraGenerationNotActive
)
}
func
TestMarkCancelled_Failed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusFailed
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_AlreadyCancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCancelled
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
999
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== GetByID ====================
func
TestGetByID_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
Model
:
"sora2-landscape-10s"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
}
func
TestGetByID_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestGetByID_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
}
// ==================== List ====================
func
TestList_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusPending
,
MediaType
:
"image"
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gens
,
total
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
gens
,
2
)
// 只有 userID=1 的
require
.
Equal
(
t
,
int64
(
2
),
total
)
}
func
TestList_DefaultPagination
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
NoError
(
t
,
err
)
}
func
TestList_MaxPageSize
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// pageSize > 100 → 应限制为 100
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
200
})
require
.
NoError
(
t
,
err
)
}
func
TestList_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
listErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
Error
(
t
,
err
)
}
// ==================== Delete ====================
func
TestDelete_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
StorageType
:
SoraStorageTypeUpstream
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
func
TestDelete_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
}
func
TestDelete_S3Cleanup_NilS3
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// s3Storage 为 nil,跳过清理
}
func
TestDelete_QuotaRelease_NilQuota
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// quotaService 为 nil,跳过释放
}
func
TestDelete_NonS3NoCleanup
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
}
func
TestDelete_DeleteRepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeUpstream
}
repo
.
deleteErr
=
fmt
.
Errorf
(
"delete failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
}
// ==================== CountActiveByUser ====================
func
TestCountActiveByUser_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
// 不算
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
count
)
}
func
TestCountActiveByUser_NoActive
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
count
)
}
func
TestCountActiveByUser_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
countErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== ResolveMediaURLs ====================
func
TestResolveMediaURLs_NilGen
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
nil
))
}
func
TestResolveMediaURLs_NonS3
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeUpstream
,
MediaURL
:
"https://original.com/v.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"https://original.com/v.mp4"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3NilStorage
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
}
func
TestResolveMediaURLs_Local
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeLocal
,
MediaURL
:
"/video/2024/01/01/file.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"/video/2024/01/01/file.mp4"
,
gen
.
MediaURL
)
// 不变
}
// ==================== 状态流转完整测试 ====================
func
TestStatusTransition_PendingToCompletedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 1. 创建 pending
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
// 2. 标记 generating
err
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
"task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
// 3. 标记 completed
err
=
svc
.
MarkCompleted
(
context
.
Background
(),
gen
.
ID
,
"https://s3.com/video.mp4"
,
nil
,
SoraStorageTypeS3
,
[]
string
{
"key"
},
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_PendingToFailedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
gen
.
ID
,
"上游超时"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusFailed
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
require
.
Equal
(
t
,
"上游超时"
,
repo
.
gens
[
gen
.
ID
]
.
ErrorMessage
)
}
func
TestStatusTransition_PendingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_GeneratingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
// ==================== 权限隔离测试 ====================
func
TestUserIsolation_CannotAccessOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
// 用户 2 尝试访问用户 1 的记录
_
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestUserIsolation_CannotDeleteOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
Delete
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
// ==================== Delete: S3 清理 + 配额释放路径 ====================
func
TestDelete_S3Cleanup_WithS3Storage
(
t
*
testing
.
T
)
{
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
// 验证 Delete 仍然成功(S3 错误只是记录日志)
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/abc.mp4"
},
}
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// S3 清理失败不影响删除
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_QuotaRelease_WithQuotaService
(
t
*
testing
.
T
)
{
// 有配额服务时,删除 S3 类型记录会释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1048576
,
// 1MB
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2097152
}
// 2MB
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// 配额应被释放: 2MB - 1MB = 1MB
require
.
Equal
(
t
,
int64
(
1048576
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_S3Cleanup_And_QuotaRelease
(
t
*
testing
.
T
)
{
// S3 清理 + 配额释放同时触发
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
},
FileSizeBytes
:
512
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
require
.
Equal
(
t
,
int64
(
512
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_LocalStorage
(
t
*
testing
.
T
)
{
// 本地存储同样需要释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2048
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_ZeroFileSize
(
t
*
testing
.
T
)
{
// FileSizeBytes=0 跳过配额释放
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
0
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func
TestResolveMediaURLs_S3_CDN_SingleKey
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/video.mp4"
,
gen
.
MediaURL
)
}
func
TestResolveMediaURLs_S3_CDN_MultipleKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com/"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
"sora/1/2024/01/01/img3.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
// 主 URL 更新为第一个 key 的 CDN URL
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURL
)
// 多图 URLs 全部更新
require
.
Len
(
t
,
gen
.
MediaURLs
,
3
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURLs
[
0
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img2.png"
,
gen
.
MediaURLs
[
1
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img3.png"
,
gen
.
MediaURLs
[
2
])
}
func
TestResolveMediaURLs_S3_EmptyKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"original"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3_GetAccessURL_Error
(
t
*
testing
.
T
)
{
// 使用无 settingService 的 S3 Storage,getClient 会失败
s3Storage
:=
newS3StorageFailingDelete
()
// 同样 GetAccessURL 也会失败
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// GetAccessURL 失败应传播错误
}
func
TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond
(
t
*
testing
.
T
)
{
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// 第一个 key 的 GetAccessURL 就会失败
}
backend/internal/service/sora_media_cleanup_service.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
var
soraCleanupCronParser
=
cron
.
NewParser
(
cron
.
Minute
|
cron
.
Hour
|
cron
.
Dom
|
cron
.
Month
|
cron
.
Dow
)
// SoraMediaCleanupService 定期清理本地媒体文件
type
SoraMediaCleanupService
struct
{
storage
*
SoraMediaStorage
cfg
*
config
.
Config
cron
*
cron
.
Cron
startOnce
sync
.
Once
stopOnce
sync
.
Once
}
func
NewSoraMediaCleanupService
(
storage
*
SoraMediaStorage
,
cfg
*
config
.
Config
)
*
SoraMediaCleanupService
{
return
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraMediaCleanupService
)
Start
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
if
!
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Enabled
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (disabled)"
)
return
}
if
s
.
storage
==
nil
||
!
s
.
storage
.
Enabled
()
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (storage disabled)"
)
return
}
s
.
startOnce
.
Do
(
func
()
{
schedule
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Schedule
)
if
schedule
==
""
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (empty schedule)"
)
return
}
loc
:=
time
.
Local
if
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
)
!=
""
{
if
parsed
,
err
:=
time
.
LoadLocation
(
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
));
err
==
nil
&&
parsed
!=
nil
{
loc
=
parsed
}
}
c
:=
cron
.
New
(
cron
.
WithParser
(
soraCleanupCronParser
),
cron
.
WithLocation
(
loc
))
if
_
,
err
:=
c
.
AddFunc
(
schedule
,
func
()
{
s
.
runCleanup
()
});
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (invalid schedule=%q): %v"
,
schedule
,
err
)
return
}
s
.
cron
=
c
s
.
cron
.
Start
()
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] started (schedule=%q tz=%s)"
,
schedule
,
loc
.
String
())
})
}
func
(
s
*
SoraMediaCleanupService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
cron
!=
nil
{
ctx
:=
s
.
cron
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
case
<-
time
.
After
(
3
*
time
.
Second
)
:
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cron stop timed out"
)
}
}
})
}
func
(
s
*
SoraMediaCleanupService
)
runCleanup
()
{
if
s
.
cfg
==
nil
||
s
.
storage
==
nil
{
return
}
retention
:=
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
RetentionDays
if
retention
<=
0
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] skipped (retention_days=%d)"
,
retention
)
return
}
cutoff
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
retention
)
deleted
:=
0
roots
:=
[]
string
{
s
.
storage
.
ImageRoot
(),
s
.
storage
.
VideoRoot
()}
for
_
,
root
:=
range
roots
{
if
root
==
""
{
continue
}
_
=
filepath
.
Walk
(
root
,
func
(
p
string
,
info
os
.
FileInfo
,
err
error
)
error
{
if
err
!=
nil
{
return
nil
}
if
info
.
IsDir
()
{
return
nil
}
if
info
.
ModTime
()
.
Before
(
cutoff
)
{
if
rmErr
:=
os
.
Remove
(
p
);
rmErr
==
nil
{
deleted
++
}
}
return
nil
})
}
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cleanup finished, deleted=%d"
,
deleted
)
}
backend/internal/service/sora_media_cleanup_service_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
service
import
(
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaCleanupService_RunCleanup_NilCfg
(
t
*
testing
.
T
)
{
storage
:=
&
SoraMediaStorage
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
nil
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_NilStorage
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
nil
,
cfg
:
cfg
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_ZeroRetention
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
0
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
// retention=0 应跳过清理
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_Start_NilCfg
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
nil
)
svc
.
Start
()
// cfg == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_StorageDisabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// storage == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_WithTimezone
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Timezone
:
"Asia/Shanghai"
,
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Start_Disabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
false
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// 不应 panic,也不应启动 cron
}
func
TestSoraMediaCleanupService_Start_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Start
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Start_EmptySchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
""
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 空 schedule 不应启动
}
func
TestSoraMediaCleanupService_Start_InvalidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"invalid-cron"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 无效 schedule 不应 panic
}
func
TestSoraMediaCleanupService_Start_ValidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Stop_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Stop
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Stop_WithoutStart
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
&
config
.
Config
{})
svc
.
Stop
()
// cron 未启动时 Stop 不应 panic
}
func
TestSoraMediaCleanupService_RunCleanup
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
1
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
require
.
NoError
(
t
,
storage
.
EnsureLocalDirs
())
oldImage
:=
filepath
.
Join
(
storage
.
ImageRoot
(),
"old.png"
)
newVideo
:=
filepath
.
Join
(
storage
.
VideoRoot
(),
"new.mp4"
)
require
.
NoError
(
t
,
os
.
WriteFile
(
oldImage
,
[]
byte
(
"old"
),
0
o644
))
require
.
NoError
(
t
,
os
.
WriteFile
(
newVideo
,
[]
byte
(
"new"
),
0
o644
))
oldTime
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
require
.
NoError
(
t
,
os
.
Chtimes
(
oldImage
,
oldTime
,
oldTime
))
cleanup
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
cleanup
.
runCleanup
()
require
.
NoFileExists
(
t
,
oldImage
)
require
.
FileExists
(
t
,
newVideo
)
}
backend/internal/service/sora_media_sign.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func
SignSoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
key
string
)
string
{
key
=
strings
.
TrimSpace
(
key
)
if
key
==
""
{
return
""
}
mac
:=
hmac
.
New
(
sha256
.
New
,
[]
byte
(
key
))
if
_
,
err
:=
mac
.
Write
([]
byte
(
buildSoraMediaSignPayload
(
path
,
query
)));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
"|"
));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
strconv
.
FormatInt
(
expires
,
10
)));
err
!=
nil
{
return
""
}
return
hex
.
EncodeToString
(
mac
.
Sum
(
nil
))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func
VerifySoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
signature
string
,
key
string
)
bool
{
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
return
false
}
expected
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
expected
==
""
{
return
false
}
return
hmac
.
Equal
([]
byte
(
signature
),
[]
byte
(
expected
))
}
func
buildSoraMediaSignPayload
(
path
string
,
query
string
)
string
{
if
strings
.
TrimSpace
(
query
)
==
""
{
return
path
}
return
path
+
"?"
+
query
}
backend/internal/service/sora_media_sign_test.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
"testing"
func
TestSoraMediaSignVerify
(
t
*
testing
.
T
)
{
key
:=
"test-key"
path
:=
"/tmp/abc.png"
query
:=
"a=1&b=2"
expires
:=
int64
(
1700000000
)
signature
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
signature
==
""
{
t
.
Fatal
(
"签名为空"
)
}
if
!
VerifySoraMediaURL
(
path
,
query
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名校验失败"
)
}
if
VerifySoraMediaURL
(
path
,
"a=1"
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名参数不同仍然通过"
)
}
if
VerifySoraMediaURL
(
path
,
query
,
expires
+
1
,
signature
,
key
)
{
t
.
Fatal
(
"签名过期校验未失败"
)
}
}
func
TestSoraMediaSignWithEmptyKey
(
t
*
testing
.
T
)
{
signature
:=
SignSoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
""
)
if
signature
!=
""
{
t
.
Fatalf
(
"空密钥不应生成签名"
)
}
if
VerifySoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
"sig"
,
""
)
{
t
.
Fatalf
(
"空密钥不应通过校验"
)
}
}
backend/internal/service/sora_media_storage.go
deleted
100644 → 0
View file @
dbb248df
package
service
import
(
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const
(
soraStorageDefaultRoot
=
"/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type
SoraMediaStorage
struct
{
cfg
*
config
.
Config
root
string
imageRoot
string
videoRoot
string
downloadTimeout
time
.
Duration
maxDownloadBytes
int64
fallbackToUpstream
bool
debug
bool
sem
chan
struct
{}
ready
bool
}
func
NewSoraMediaStorage
(
cfg
*
config
.
Config
)
*
SoraMediaStorage
{
storage
:=
&
SoraMediaStorage
{
cfg
:
cfg
}
storage
.
refreshConfig
()
if
storage
.
Enabled
()
{
if
err
:=
storage
.
EnsureLocalDirs
();
err
!=
nil
{
log
.
Printf
(
"[SoraStorage] 初始化失败: %v"
,
err
)
}
}
return
storage
}
func
(
s
*
SoraMediaStorage
)
Enabled
()
bool
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
false
}
return
strings
.
ToLower
(
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Type
))
==
"local"
}
func
(
s
*
SoraMediaStorage
)
Root
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
root
}
func
(
s
*
SoraMediaStorage
)
ImageRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
imageRoot
}
func
(
s
*
SoraMediaStorage
)
VideoRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
videoRoot
}
func
(
s
*
SoraMediaStorage
)
refreshConfig
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
root
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
LocalPath
)
if
root
==
""
{
root
=
soraStorageDefaultRoot
}
root
=
filepath
.
Clean
(
root
)
if
!
filepath
.
IsAbs
(
root
)
{
if
absRoot
,
err
:=
filepath
.
Abs
(
root
);
err
==
nil
{
root
=
absRoot
}
}
s
.
root
=
root
s
.
imageRoot
=
filepath
.
Join
(
root
,
"image"
)
s
.
videoRoot
=
filepath
.
Join
(
root
,
"video"
)
maxConcurrent
:=
s
.
cfg
.
Sora
.
Storage
.
MaxConcurrentDownloads
if
maxConcurrent
<=
0
{
maxConcurrent
=
4
}
timeoutSeconds
:=
s
.
cfg
.
Sora
.
Storage
.
DownloadTimeoutSeconds
if
timeoutSeconds
<=
0
{
timeoutSeconds
=
120
}
s
.
downloadTimeout
=
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
maxBytes
:=
s
.
cfg
.
Sora
.
Storage
.
MaxDownloadBytes
if
maxBytes
<=
0
{
maxBytes
=
200
<<
20
}
s
.
maxDownloadBytes
=
maxBytes
s
.
fallbackToUpstream
=
s
.
cfg
.
Sora
.
Storage
.
FallbackToUpstream
s
.
debug
=
s
.
cfg
.
Sora
.
Storage
.
Debug
s
.
sem
=
make
(
chan
struct
{},
maxConcurrent
)
}
// EnsureLocalDirs 创建并校验本地目录
func
(
s
*
SoraMediaStorage
)
EnsureLocalDirs
()
error
{
if
s
==
nil
||
!
s
.
Enabled
()
{
return
nil
}
if
err
:=
os
.
MkdirAll
(
s
.
imageRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create image dir: %w"
,
err
)
}
if
err
:=
os
.
MkdirAll
(
s
.
videoRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create video dir: %w"
,
err
)
}
s
.
ready
=
true
return
nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func
(
s
*
SoraMediaStorage
)
StoreFromURLs
(
ctx
context
.
Context
,
mediaType
string
,
urls
[]
string
)
([]
string
,
error
)
{
if
len
(
urls
)
==
0
{
return
nil
,
nil
}
if
s
==
nil
||
!
s
.
Enabled
()
{
return
urls
,
nil
}
if
!
s
.
ready
{
if
err
:=
s
.
EnsureLocalDirs
();
err
!=
nil
{
return
nil
,
err
}
}
results
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
relative
,
err
:=
s
.
downloadAndStore
(
ctx
,
mediaType
,
raw
)
if
err
!=
nil
{
if
s
.
fallbackToUpstream
{
results
=
append
(
results
,
raw
)
continue
}
return
nil
,
err
}
results
=
append
(
results
,
relative
)
}
return
results
,
nil
}
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
TotalSizeByRelativePaths
(
paths
[]
string
)
(
int64
,
error
)
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
0
,
nil
}
var
total
int64
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
info
,
err
:=
os
.
Stat
(
localPath
)
if
err
!=
nil
{
if
os
.
IsNotExist
(
err
)
{
continue
}
return
0
,
err
}
if
info
.
Mode
()
.
IsRegular
()
{
total
+=
info
.
Size
()
}
}
return
total
,
nil
}
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
DeleteByRelativePaths
(
paths
[]
string
)
error
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
nil
}
var
lastErr
error
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
if
err
:=
os
.
Remove
(
localPath
);
err
!=
nil
&&
!
os
.
IsNotExist
(
err
)
{
lastErr
=
err
}
}
return
lastErr
}
func
(
s
*
SoraMediaStorage
)
resolveLocalPath
(
relativePath
string
)
(
string
,
error
)
{
if
s
==
nil
||
strings
.
TrimSpace
(
relativePath
)
==
""
{
return
""
,
errors
.
New
(
"empty path"
)
}
cleaned
:=
path
.
Clean
(
relativePath
)
if
!
strings
.
HasPrefix
(
cleaned
,
"/image/"
)
&&
!
strings
.
HasPrefix
(
cleaned
,
"/video/"
)
{
return
""
,
errors
.
New
(
"not a local media path"
)
}
if
strings
.
TrimSpace
(
s
.
root
)
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
relative
:=
strings
.
TrimPrefix
(
cleaned
,
"/"
)
return
filepath
.
Join
(
s
.
root
,
filepath
.
FromSlash
(
relative
)),
nil
}
func
(
s
*
SoraMediaStorage
)
downloadAndStore
(
ctx
context
.
Context
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
if
strings
.
TrimSpace
(
rawURL
)
==
""
{
return
""
,
errors
.
New
(
"empty url"
)
}
root
:=
s
.
imageRoot
if
mediaType
==
"video"
{
root
=
s
.
videoRoot
}
if
root
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
retries
:=
3
for
attempt
:=
1
;
attempt
<=
retries
;
attempt
++
{
release
,
err
:=
s
.
acquire
(
ctx
)
if
err
!=
nil
{
return
""
,
err
}
relative
,
err
:=
s
.
downloadOnce
(
ctx
,
root
,
mediaType
,
rawURL
)
release
()
if
err
==
nil
{
return
relative
,
nil
}
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 下载失败(%d/%d): %s err=%v"
,
attempt
,
retries
,
sanitizeMediaLogURL
(
rawURL
),
err
)
}
if
attempt
<
retries
{
time
.
Sleep
(
time
.
Duration
(
attempt
*
attempt
)
*
time
.
Second
)
continue
}
return
""
,
err
}
return
""
,
errors
.
New
(
"download retries exhausted"
)
}
func
(
s
*
SoraMediaStorage
)
downloadOnce
(
ctx
context
.
Context
,
root
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
rawURL
,
nil
)
if
err
!=
nil
{
return
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
s
.
downloadTimeout
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
return
""
,
fmt
.
Errorf
(
"download failed: %d %s"
,
resp
.
StatusCode
,
string
(
body
))
}
ext
:=
normalizeSoraFileExt
(
fileExtFromURL
(
rawURL
))
if
ext
==
""
{
ext
=
normalizeSoraFileExt
(
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
)))
}
if
ext
==
""
{
ext
=
".bin"
}
if
s
.
maxDownloadBytes
>
0
&&
resp
.
ContentLength
>
s
.
maxDownloadBytes
{
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
resp
.
ContentLength
)
}
storageRoot
,
err
:=
os
.
OpenRoot
(
root
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
storageRoot
.
Close
()
}()
datePath
:=
time
.
Now
()
.
Format
(
"2006/01/02"
)
datePathFS
:=
filepath
.
FromSlash
(
datePath
)
if
err
:=
storageRoot
.
MkdirAll
(
datePathFS
,
0
o755
);
err
!=
nil
{
return
""
,
err
}
filename
:=
uuid
.
NewString
()
+
ext
filePath
:=
filepath
.
Join
(
datePathFS
,
filename
)
out
,
err
:=
storageRoot
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o644
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
out
.
Close
()
}()
limited
:=
io
.
LimitReader
(
resp
.
Body
,
s
.
maxDownloadBytes
+
1
)
written
,
err
:=
io
.
Copy
(
out
,
limited
)
if
err
!=
nil
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
err
}
if
s
.
maxDownloadBytes
>
0
&&
written
>
s
.
maxDownloadBytes
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
written
)
}
relative
:=
path
.
Join
(
"/"
,
mediaType
,
datePath
,
filename
)
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 已落地 %s -> %s"
,
sanitizeMediaLogURL
(
rawURL
),
relative
)
}
return
relative
,
nil
}
func
(
s
*
SoraMediaStorage
)
acquire
(
ctx
context
.
Context
)
(
func
(),
error
)
{
if
s
.
sem
==
nil
{
return
func
()
{},
nil
}
select
{
case
s
.
sem
<-
struct
{}{}
:
return
func
()
{
<-
s
.
sem
},
nil
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
}
}
func
fileExtFromURL
(
raw
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
""
}
ext
:=
path
.
Ext
(
parsed
.
Path
)
return
strings
.
ToLower
(
ext
)
}
func
fileExtFromContentType
(
ct
string
)
string
{
if
ct
==
""
{
return
""
}
if
exts
,
err
:=
mime
.
ExtensionsByType
(
ct
);
err
==
nil
&&
len
(
exts
)
>
0
{
return
strings
.
ToLower
(
exts
[
0
])
}
return
""
}
func
normalizeSoraFileExt
(
ext
string
)
string
{
ext
=
strings
.
ToLower
(
strings
.
TrimSpace
(
ext
))
switch
ext
{
case
".png"
,
".jpg"
,
".jpeg"
,
".gif"
,
".webp"
,
".bmp"
,
".svg"
,
".tif"
,
".tiff"
,
".heic"
,
".mp4"
,
".mov"
,
".webm"
,
".m4v"
,
".avi"
,
".mkv"
,
".3gp"
,
".flv"
:
return
ext
default
:
return
""
}
}
func
removePartialDownload
(
root
*
os
.
Root
,
filePath
string
)
{
if
root
==
nil
||
strings
.
TrimSpace
(
filePath
)
==
""
{
return
}
_
=
root
.
Remove
(
filePath
)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func
sanitizeMediaLogURL
(
rawURL
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
rawURL
)
if
err
!=
nil
{
if
len
(
rawURL
)
>
80
{
return
rawURL
[
:
80
]
+
"..."
}
return
rawURL
}
safe
:=
parsed
.
Scheme
+
"://"
+
parsed
.
Host
+
parsed
.
Path
if
len
(
safe
)
>
120
{
return
safe
[
:
120
]
+
"..."
}
return
safe
}
backend/internal/service/sora_media_storage_test.go
deleted
100644 → 0
View file @
dbb248df
//go:build unit
package
service
import
(
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaStorage_StoreFromURLs
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"data"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxConcurrentDownloads
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
urls
,
1
)
require
.
True
(
t
,
strings
.
HasPrefix
(
urls
[
0
],
"/image/"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
urls
[
0
],
".png"
))
localPath
:=
filepath
.
Join
(
tmpDir
,
filepath
.
FromSlash
(
strings
.
TrimPrefix
(
urls
[
0
],
"/"
)))
require
.
FileExists
(
t
,
localPath
)
}
func
TestSoraMediaStorage_FallbackToUpstream
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
FallbackToUpstream
:
true
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
url
:=
server
.
URL
+
"/broken.png"
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
url
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
url
},
urls
)
}
func
TestSoraMediaStorage_MaxDownloadBytes
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"too-large"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxDownloadBytes
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
_
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
Error
(
t
,
err
)
}
func
TestNormalizeSoraFileExt
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
".png"
,
normalizeSoraFileExt
(
".PNG"
))
require
.
Equal
(
t
,
".mp4"
,
normalizeSoraFileExt
(
".mp4"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
"../../etc/passwd"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
".php"
))
}
func
TestRemovePartialDownload
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
root
,
err
:=
os
.
OpenRoot
(
tmpDir
)
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
root
.
Close
()
}()
filePath
:=
"partial.bin"
f
,
err
:=
root
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o600
)
require
.
NoError
(
t
,
err
)
_
,
_
=
f
.
WriteString
(
"partial"
)
_
=
f
.
Close
()
removePartialDownload
(
root
,
filePath
)
_
,
err
=
root
.
Stat
(
filePath
)
require
.
Error
(
t
,
err
)
require
.
True
(
t
,
os
.
IsNotExist
(
err
))
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment