Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d757df8a
Unverified
Commit
d757df8a
authored
Apr 05, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 05, 2026
Browse files
Merge pull request #1463 from touwaeriol/feat/remove-sora
revert: completely remove Sora platform
parents
f585a15e
19655a15
Changes
163
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/gateway_service_selection_failure_stats_test.go
View file @
d757df8a
...
...
@@ -9,35 +9,35 @@ import (
func
TestCollectSelectionFailureStats
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"
sora2-landscape-10s
"
model
:=
"
gpt-5.4
"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
accounts
:=
[]
Account
{
// excluded
{
ID
:
1
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
// unschedulable
{
ID
:
2
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
false
,
},
// platform filtered
{
ID
:
3
,
Platform
:
Platform
OpenAI
,
Platform
:
Platform
Antigravity
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
// model unsupported
{
ID
:
4
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
...
...
@@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// model rate limited
{
ID
:
5
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
...
...
@@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// eligible
{
ID
:
6
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
model
,
Platform
Sora
,
excluded
,
false
)
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
model
,
Platform
OpenAI
,
excluded
,
false
)
if
stats
.
Total
!=
6
{
t
.
Fatalf
(
"total=%d want=6"
,
stats
.
Total
)
...
...
@@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
}
}
func
TestDiagnoseSelectionFailure_
Sora
UnschedulableDetail
(
t
*
testing
.
T
)
{
func
TestDiagnoseSelectionFailure_UnschedulableDetail
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
acc
:=
&
Account
{
ID
:
7
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
false
,
}
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
"
sora2-landscape-10s
"
,
Platform
Sora
,
map
[
int64
]
struct
{}{},
false
)
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
"
gpt-5.4
"
,
Platform
OpenAI
,
map
[
int64
]
struct
{}{},
false
)
if
diagnosis
.
Category
!=
"unschedulable"
{
t
.
Fatalf
(
"category=%s want=unschedulable"
,
diagnosis
.
Category
)
}
if
diagnosis
.
Detail
!=
"schedulable
=false
"
{
t
.
Fatalf
(
"detail=%s want=schedulable
=false
"
,
diagnosis
.
Detail
)
if
diagnosis
.
Detail
!=
"
generic_un
schedulable"
{
t
.
Fatalf
(
"detail=%s want=
generic_un
schedulable"
,
diagnosis
.
Detail
)
}
}
func
TestDiagnoseSelectionFailure_
Sora
ModelRateLimitedDetail
(
t
*
testing
.
T
)
{
func
TestDiagnoseSelectionFailure_ModelRateLimitedDetail
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"
sora2-landscape-10s
"
model
:=
"
gpt-5.4
"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
UTC
()
.
Format
(
time
.
RFC3339
)
acc
:=
&
Account
{
ID
:
8
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
...
...
@@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
},
}
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
model
,
Platform
Sora
,
map
[
int64
]
struct
{}{},
false
)
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
model
,
Platform
OpenAI
,
map
[
int64
]
struct
{}{},
false
)
if
diagnosis
.
Category
!=
"model_rate_limited"
{
t
.
Fatalf
(
"category=%s want=model_rate_limited"
,
diagnosis
.
Category
)
}
...
...
backend/internal/service/gateway_service_sora_model_support_test.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
"testing"
func
TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected sora model to be supported when model_mapping is empty"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-4o"
:
"gpt-4o"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected sora model to be supported when mapping has no sora selectors"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"sora2"
:
"sora2"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-15s"
)
{
t
.
Fatalf
(
"expected family selector sora2 to support sora2-landscape-15s"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"sy_8"
:
"sy_8"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected underlying model selector sy_8 to support sora2-landscape-10s"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-image"
:
"gpt-image"
,
},
},
}
if
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected video model to be blocked when mapping explicitly only allows gpt-image"
)
}
}
backend/internal/service/gateway_service_sora_scheduling_test.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"testing"
"time"
)
func
TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
now
:=
time
.
Now
()
past
:=
now
.
Add
(
-
1
*
time
.
Minute
)
future
:=
now
.
Add
(
5
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
AutoPauseOnExpired
:
true
,
ExpiresAt
:
&
past
,
OverloadUntil
:
&
future
,
RateLimitResetAt
:
&
future
,
}
if
!
svc
.
isAccountSchedulableForSelection
(
acc
)
{
t
.
Fatalf
(
"expected sora account to ignore generic expiry/overload/rate-limit windows"
)
}
}
func
TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
future
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
future
,
}
if
svc
.
isAccountSchedulableForSelection
(
acc
)
{
t
.
Fatalf
(
"expected non-sora account to keep generic schedulable checks"
)
}
}
func
TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"sora2-landscape-10s"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
UTC
()
.
Format
(
time
.
RFC3339
)
globalResetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
globalResetAt
,
Extra
:
map
[
string
]
any
{
"model_rate_limits"
:
map
[
string
]
any
{
model
:
map
[
string
]
any
{
"rate_limit_reset_at"
:
resetAt
,
},
},
},
}
if
svc
.
isAccountSchedulableForModelSelection
(
context
.
Background
(),
acc
,
model
)
{
t
.
Fatalf
(
"expected sora account to be blocked by model scope rate limit"
)
}
}
func
TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
future
:=
time
.
Now
()
.
Add
(
3
*
time
.
Minute
)
accounts
:=
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
future
,
},
}
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
"sora2-landscape-10s"
,
PlatformSora
,
map
[
int64
]
struct
{}{},
false
)
if
stats
.
Unschedulable
!=
0
||
stats
.
Eligible
!=
1
{
t
.
Fatalf
(
"unexpected stats: unschedulable=%d eligible=%d"
,
stats
.
Unschedulable
,
stats
.
Eligible
)
}
}
backend/internal/service/group.go
View file @
d757df8a
...
...
@@ -26,15 +26,6 @@ type Group struct {
ImagePrice2K
*
float64
ImagePrice4K
*
float64
// Sora 按次计费配置(阶段 1)
SoraImagePrice360
*
float64
SoraImagePrice540
*
float64
SoraVideoPricePerRequest
*
float64
SoraVideoPricePerRequestHD
*
float64
// Sora 存储配额
SoraStorageQuotaBytes
int64
// Claude Code 客户端限制
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
...
...
@@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
}
}
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
func
(
g
*
Group
)
GetSoraImagePrice
(
imageSize
string
)
*
float64
{
switch
imageSize
{
case
"360"
:
return
g
.
SoraImagePrice360
case
"540"
:
return
g
.
SoraImagePrice540
default
:
return
g
.
SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
...
...
backend/internal/service/openai_oauth_service.go
View file @
d757df8a
...
...
@@ -3,30 +3,15 @@ package service
import
(
"context"
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
var
openAISoraSessionAuthURL
=
"https://sora.chatgpt.com/api/auth/session"
var
soraSessionCookiePattern
=
regexp
.
MustCompile
(
`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`
)
type
soraSessionChunk
struct
{
index
int
value
string
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type
OpenAIOAuthService
struct
{
sessionStore
*
openai
.
SessionStore
...
...
@@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return
s
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
""
)
}
// RefreshTokenWithClientID refreshes an OpenAI
/Sora
OAuth token with optional client_id.
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func
(
s
*
OpenAIOAuthService
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
string
,
proxyURL
string
,
clientID
string
)
(
*
OpenAITokenInfo
,
error
)
{
tokenResp
,
err
:=
s
.
oauthClient
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
clientID
)
if
err
!=
nil
{
...
...
@@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo
.
PrivacyMode
=
disableOpenAITraining
(
ctx
,
s
.
privacyClientFactory
,
tokenInfo
.
AccessToken
,
proxyURL
)
}
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func
(
s
*
OpenAIOAuthService
)
ExchangeSoraSessionToken
(
ctx
context
.
Context
,
sessionToken
string
,
proxyID
*
int64
)
(
*
OpenAITokenInfo
,
error
)
{
sessionToken
=
normalizeSoraSessionTokenInput
(
sessionToken
)
if
strings
.
TrimSpace
(
sessionToken
)
==
""
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"SORA_SESSION_TOKEN_REQUIRED"
,
"session_token is required"
)
}
proxyURL
,
err
:=
s
.
resolveProxyURL
(
ctx
,
proxyID
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
openAISoraSessionAuthURL
,
nil
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusInternalServerError
,
"SORA_SESSION_REQUEST_BUILD_FAILED"
,
"failed to build request: %v"
,
err
)
}
req
.
Header
.
Set
(
"Cookie"
,
"__Secure-next-auth.session-token="
+
strings
.
TrimSpace
(
sessionToken
))
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
req
.
Header
.
Set
(
"Origin"
,
"https://sora.chatgpt.com"
)
req
.
Header
.
Set
(
"Referer"
,
"https://sora.chatgpt.com/"
)
req
.
Header
.
Set
(
"User-Agent"
,
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
120
*
time
.
Second
,
})
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_CLIENT_FAILED"
,
"create http client failed: %v"
,
err
)
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_REQUEST_FAILED"
,
"request failed: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_EXCHANGE_FAILED"
,
"status %d: %s"
,
resp
.
StatusCode
,
strings
.
TrimSpace
(
string
(
body
)))
}
var
sessionResp
struct
{
AccessToken
string
`json:"accessToken"`
Expires
string
`json:"expires"`
User
struct
{
Email
string
`json:"email"`
Name
string
`json:"name"`
}
`json:"user"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
sessionResp
);
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_PARSE_FAILED"
,
"failed to parse response: %v"
,
err
)
}
if
strings
.
TrimSpace
(
sessionResp
.
AccessToken
)
==
""
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadGateway
,
"SORA_SESSION_ACCESS_TOKEN_MISSING"
,
"session exchange response missing access token"
)
}
expiresAt
:=
time
.
Now
()
.
Add
(
time
.
Hour
)
.
Unix
()
if
strings
.
TrimSpace
(
sessionResp
.
Expires
)
!=
""
{
if
parsed
,
parseErr
:=
time
.
Parse
(
time
.
RFC3339
,
sessionResp
.
Expires
);
parseErr
==
nil
{
expiresAt
=
parsed
.
Unix
()
}
}
expiresIn
:=
expiresAt
-
time
.
Now
()
.
Unix
()
if
expiresIn
<
0
{
expiresIn
=
0
}
return
&
OpenAITokenInfo
{
AccessToken
:
strings
.
TrimSpace
(
sessionResp
.
AccessToken
),
ExpiresIn
:
expiresIn
,
ExpiresAt
:
expiresAt
,
ClientID
:
openai
.
SoraClientID
,
Email
:
strings
.
TrimSpace
(
sessionResp
.
User
.
Email
),
},
nil
}
func
normalizeSoraSessionTokenInput
(
raw
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
""
}
matches
:=
soraSessionCookiePattern
.
FindAllStringSubmatch
(
trimmed
,
-
1
)
if
len
(
matches
)
==
0
{
return
sanitizeSessionToken
(
trimmed
)
}
chunkMatches
:=
make
([]
soraSessionChunk
,
0
,
len
(
matches
))
singleValues
:=
make
([]
string
,
0
,
len
(
matches
))
for
_
,
match
:=
range
matches
{
if
len
(
match
)
<
3
{
continue
}
value
:=
sanitizeSessionToken
(
match
[
2
])
if
value
==
""
{
continue
}
if
strings
.
TrimSpace
(
match
[
1
])
==
""
{
singleValues
=
append
(
singleValues
,
value
)
continue
}
idx
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
match
[
1
]))
if
err
!=
nil
||
idx
<
0
{
continue
}
chunkMatches
=
append
(
chunkMatches
,
soraSessionChunk
{
index
:
idx
,
value
:
value
,
})
}
if
merged
:=
mergeLatestSoraSessionChunks
(
chunkMatches
);
merged
!=
""
{
return
merged
}
if
len
(
singleValues
)
>
0
{
return
singleValues
[
len
(
singleValues
)
-
1
]
}
return
""
}
func
mergeSoraSessionChunkSegment
(
chunks
[]
soraSessionChunk
,
requiredMaxIndex
int
,
requireComplete
bool
)
string
{
if
len
(
chunks
)
==
0
{
return
""
}
byIndex
:=
make
(
map
[
int
]
string
,
len
(
chunks
))
for
_
,
chunk
:=
range
chunks
{
byIndex
[
chunk
.
index
]
=
chunk
.
value
}
if
_
,
ok
:=
byIndex
[
0
];
!
ok
{
return
""
}
if
requireComplete
{
for
idx
:=
0
;
idx
<=
requiredMaxIndex
;
idx
++
{
if
_
,
ok
:=
byIndex
[
idx
];
!
ok
{
return
""
}
}
}
orderedIndexes
:=
make
([]
int
,
0
,
len
(
byIndex
))
for
idx
:=
range
byIndex
{
orderedIndexes
=
append
(
orderedIndexes
,
idx
)
}
sort
.
Ints
(
orderedIndexes
)
var
builder
strings
.
Builder
for
_
,
idx
:=
range
orderedIndexes
{
if
_
,
err
:=
builder
.
WriteString
(
byIndex
[
idx
]);
err
!=
nil
{
return
""
}
}
return
sanitizeSessionToken
(
builder
.
String
())
}
func
mergeLatestSoraSessionChunks
(
chunks
[]
soraSessionChunk
)
string
{
if
len
(
chunks
)
==
0
{
return
""
}
requiredMaxIndex
:=
0
for
_
,
chunk
:=
range
chunks
{
if
chunk
.
index
>
requiredMaxIndex
{
requiredMaxIndex
=
chunk
.
index
}
}
groupStarts
:=
make
([]
int
,
0
,
len
(
chunks
))
for
idx
,
chunk
:=
range
chunks
{
if
chunk
.
index
==
0
{
groupStarts
=
append
(
groupStarts
,
idx
)
}
}
if
len
(
groupStarts
)
==
0
{
return
mergeSoraSessionChunkSegment
(
chunks
,
requiredMaxIndex
,
false
)
}
for
i
:=
len
(
groupStarts
)
-
1
;
i
>=
0
;
i
--
{
start
:=
groupStarts
[
i
]
end
:=
len
(
chunks
)
if
i
+
1
<
len
(
groupStarts
)
{
end
=
groupStarts
[
i
+
1
]
}
if
merged
:=
mergeSoraSessionChunkSegment
(
chunks
[
start
:
end
],
requiredMaxIndex
,
true
);
merged
!=
""
{
return
merged
}
}
return
mergeSoraSessionChunkSegment
(
chunks
,
requiredMaxIndex
,
false
)
}
func
sanitizeSessionToken
(
raw
string
)
string
{
token
:=
strings
.
TrimSpace
(
raw
)
token
=
strings
.
Trim
(
token
,
"
\"
'`"
)
token
=
strings
.
TrimSuffix
(
token
,
";"
)
return
strings
.
TrimSpace
(
token
)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
// RefreshAccountToken refreshes token for an OpenAI OAuth account
func
(
s
*
OpenAIOAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
if
account
.
Platform
!=
PlatformOpenAI
&&
account
.
Platform
!=
PlatformSora
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT"
,
"account is not an OpenAI
/Sora
account"
)
if
account
.
Platform
!=
PlatformOpenAI
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT"
,
"account is not an OpenAI account"
)
}
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT_TYPE"
,
"account is not an OAuth account"
)
...
...
@@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() {
s
.
sessionStore
.
Stop
()
}
func
(
s
*
OpenAIOAuthService
)
resolveProxyURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
string
,
error
)
{
if
proxyID
==
nil
{
return
""
,
nil
}
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
if
err
!=
nil
{
return
""
,
infraerrors
.
Newf
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_PROXY_NOT_FOUND"
,
"proxy not found: %v"
,
err
)
}
if
proxy
==
nil
{
return
""
,
nil
}
return
proxy
.
URL
(),
nil
}
func
normalizeOpenAIOAuthPlatform
(
platform
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
platform
))
{
case
PlatformSora
:
return
openai
.
OAuthPlatformSora
default
:
return
openai
.
OAuthPlatformOpenAI
}
return
openai
.
OAuthPlatformOpenAI
}
backend/internal/service/openai_oauth_service_auth_url_test.go
View file @
d757df8a
...
...
@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
func
TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient
(
t
*
testing
.
T
)
{
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientAuthURLStub
{})
defer
svc
.
Stop
()
result
,
err
:=
svc
.
GenerateAuthURL
(
context
.
Background
(),
nil
,
""
,
PlatformSora
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
result
.
AuthURL
)
require
.
NotEmpty
(
t
,
result
.
SessionID
)
parsed
,
err
:=
url
.
Parse
(
result
.
AuthURL
)
require
.
NoError
(
t
,
err
)
q
:=
parsed
.
Query
()
require
.
Equal
(
t
,
openai
.
ClientID
,
q
.
Get
(
"client_id"
))
require
.
Empty
(
t
,
q
.
Get
(
"codex_cli_simplified_flow"
))
session
,
ok
:=
svc
.
sessionStore
.
Get
(
result
.
SessionID
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
backend/internal/service/openai_oauth_service_sora_session_test.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type
openaiOAuthClientNoopStub
struct
{}
func
(
s
*
openaiOAuthClientNoopStub
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_Success
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-token"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
info
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
require
.
Equal
(
t
,
"demo@example.com"
,
info
.
Email
)
require
.
Greater
(
t
,
info
.
ExpiresAt
,
int64
(
0
))
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"expires":"2099-01-01T00:00:00Z"}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
_
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"missing access token"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-cookie-value"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
"__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=chunk-0chunk-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=new-0new-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=ok-0ok-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"set-cookie"
,
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
backend/internal/service/openai_token_provider.go
View file @
d757df8a
...
...
@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider manages access_token for OpenAI
/Sora
OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
...
...
@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
(
account
.
Platform
!=
PlatformOpenAI
&&
account
.
Platform
!=
PlatformSora
)
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai
/sora
oauth account"
)
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
...
...
@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p
.
metrics
.
refreshRequests
.
Add
(
1
)
p
.
metrics
.
touchNow
()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if
account
.
Platform
==
PlatformSora
{
slog
.
Debug
(
"openai_token_refresh_skipped_for_sora"
,
"account_id"
,
account
.
ID
)
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
...
...
backend/internal/service/openai_token_provider_test.go
View file @
d757df8a
...
...
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai
/sora
oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
...
...
backend/internal/service/setting_service.go
View file @
d757df8a
...
...
@@ -22,8 +22,6 @@ import (
var
(
ErrRegistrationDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrSettingNotFound
=
infraerrors
.
NotFound
(
"SETTING_NOT_FOUND"
,
"setting not found"
)
ErrSoraS3ProfileNotFound
=
infraerrors
.
NotFound
(
"SORA_S3_PROFILE_NOT_FOUND"
,
"sora s3 profile not found"
)
ErrSoraS3ProfileExists
=
infraerrors
.
Conflict
(
"SORA_S3_PROFILE_EXISTS"
,
"sora s3 profile already exists"
)
ErrDefaultSubGroupInvalid
=
infraerrors
.
BadRequest
(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID"
,
"default subscription group must exist and be subscription type"
,
...
...
@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader
DefaultSubscriptionGroupReader
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
onS3Update
func
()
// Callback when Sora S3 settings are updated
version
string
// Application version
}
...
...
@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton
,
SettingKeyPurchaseSubscriptionEnabled
,
SettingKeyPurchaseSubscriptionURL
,
SettingKeySoraClientEnabled
,
SettingKeyCustomMenuItems
,
SettingKeyCustomEndpoints
,
SettingKeyLinuxDoConnectEnabled
,
...
...
@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
...
...
@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s
.
onUpdate
=
callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func
(
s
*
SettingService
)
SetOnS3UpdateCallback
(
callback
func
())
{
s
.
onS3Update
=
callback
}
// SetVersion sets the application version for injection into public settings
func
(
s
*
SettingService
)
SetVersion
(
version
string
)
{
s
.
version
=
version
...
...
@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url,omitempty"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
CustomMenuItems
json
.
RawMessage
`json:"custom_menu_items"`
CustomEndpoints
json
.
RawMessage
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
...
...
@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
SoraClientEnabled
:
settings
.
SoraClientEnabled
,
CustomMenuItems
:
filterUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
safeRawJSONArray
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
...
...
@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
updates
[
SettingKeyPurchaseSubscriptionEnabled
]
=
strconv
.
FormatBool
(
settings
.
PurchaseSubscriptionEnabled
)
updates
[
SettingKeyPurchaseSubscriptionURL
]
=
strings
.
TrimSpace
(
settings
.
PurchaseSubscriptionURL
)
updates
[
SettingKeySoraClientEnabled
]
=
strconv
.
FormatBool
(
settings
.
SoraClientEnabled
)
updates
[
SettingKeyCustomMenuItems
]
=
settings
.
CustomMenuItems
updates
[
SettingKeyCustomEndpoints
]
=
settings
.
CustomEndpoints
...
...
@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeyPurchaseSubscriptionEnabled
:
"false"
,
SettingKeyPurchaseSubscriptionURL
:
""
,
SettingKeySoraClientEnabled
:
"false"
,
SettingKeyCustomMenuItems
:
"[]"
,
SettingKeyCustomEndpoints
:
"[]"
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
...
...
@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
PurchaseSubscriptionEnabled
:
settings
[
SettingKeyPurchaseSubscriptionEnabled
]
==
"true"
,
PurchaseSubscriptionURL
:
strings
.
TrimSpace
(
settings
[
SettingKeyPurchaseSubscriptionURL
]),
SoraClientEnabled
:
settings
[
SettingKeySoraClientEnabled
]
==
"true"
,
CustomMenuItems
:
settings
[
SettingKeyCustomMenuItems
],
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
BackendModeEnabled
:
settings
[
SettingKeyBackendModeEnabled
]
==
"true"
,
...
...
@@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyStreamTimeoutSettings
,
string
(
data
))
}
type
soraS3ProfilesStore
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
soraS3ProfileStoreItem
`json:"items"`
}
type
soraS3ProfileStoreItem
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func
(
s
*
SettingService
)
GetSoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
profiles
,
err
:=
s
.
ListSoraS3Profiles
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
activeProfile
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
activeProfile
==
nil
{
return
&
SoraS3Settings
{},
nil
}
return
&
SoraS3Settings
{
Enabled
:
activeProfile
.
Enabled
,
Endpoint
:
activeProfile
.
Endpoint
,
Region
:
activeProfile
.
Region
,
Bucket
:
activeProfile
.
Bucket
,
AccessKeyID
:
activeProfile
.
AccessKeyID
,
SecretAccessKey
:
activeProfile
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
activeProfile
.
SecretAccessKeyConfigured
,
Prefix
:
activeProfile
.
Prefix
,
ForcePathStyle
:
activeProfile
.
ForcePathStyle
,
CDNURL
:
activeProfile
.
CDNURL
,
DefaultStorageQuotaBytes
:
activeProfile
.
DefaultStorageQuotaBytes
,
},
nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func
(
s
*
SettingService
)
SetSoraS3Settings
(
ctx
context
.
Context
,
settings
*
SoraS3Settings
)
error
{
if
settings
==
nil
{
return
fmt
.
Errorf
(
"settings cannot be nil"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
activeIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
store
.
ActiveProfileID
)
if
activeIndex
<
0
{
activeID
:=
"default"
if
hasSoraS3ProfileID
(
store
.
Items
,
activeID
)
{
activeID
=
fmt
.
Sprintf
(
"default-%d"
,
time
.
Now
()
.
Unix
())
}
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
activeID
,
Name
:
"Default"
,
UpdatedAt
:
now
,
})
store
.
ActiveProfileID
=
activeID
activeIndex
=
len
(
store
.
Items
)
-
1
}
active
:=
store
.
Items
[
activeIndex
]
active
.
Enabled
=
settings
.
Enabled
active
.
Endpoint
=
strings
.
TrimSpace
(
settings
.
Endpoint
)
active
.
Region
=
strings
.
TrimSpace
(
settings
.
Region
)
active
.
Bucket
=
strings
.
TrimSpace
(
settings
.
Bucket
)
active
.
AccessKeyID
=
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
active
.
Prefix
=
strings
.
TrimSpace
(
settings
.
Prefix
)
active
.
ForcePathStyle
=
settings
.
ForcePathStyle
active
.
CDNURL
=
strings
.
TrimSpace
(
settings
.
CDNURL
)
active
.
DefaultStorageQuotaBytes
=
maxInt64
(
settings
.
DefaultStorageQuotaBytes
,
0
)
if
settings
.
SecretAccessKey
!=
""
{
active
.
SecretAccessKey
=
settings
.
SecretAccessKey
}
active
.
UpdatedAt
=
now
store
.
Items
[
activeIndex
]
=
active
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func
(
s
*
SettingService
)
ListSoraS3Profiles
(
ctx
context
.
Context
)
(
*
SoraS3ProfileList
,
error
)
{
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
convertSoraS3ProfilesStore
(
store
),
nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func
(
s
*
SettingService
)
CreateSoraS3Profile
(
ctx
context
.
Context
,
profile
*
SoraS3Profile
,
setActive
bool
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
profileID
:=
strings
.
TrimSpace
(
profile
.
ProfileID
)
if
profileID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
if
hasSoraS3ProfileID
(
store
.
Items
,
profileID
)
{
return
nil
,
ErrSoraS3ProfileExists
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
=
append
(
store
.
Items
,
soraS3ProfileStoreItem
{
ProfileID
:
profileID
,
Name
:
name
,
Enabled
:
profile
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
profile
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
profile
.
Region
),
Bucket
:
strings
.
TrimSpace
(
profile
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
profile
.
AccessKeyID
),
SecretAccessKey
:
profile
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
profile
.
Prefix
),
ForcePathStyle
:
profile
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
profile
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
})
if
setActive
||
store
.
ActiveProfileID
==
""
{
store
.
ActiveProfileID
=
profileID
}
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
created
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
profileID
)
if
created
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
created
,
nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func
(
s
*
SettingService
)
UpdateSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
,
profile
*
SoraS3Profile
)
(
*
SoraS3Profile
,
error
)
{
if
profile
==
nil
{
return
nil
,
fmt
.
Errorf
(
"profile cannot be nil"
)
}
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
target
:=
store
.
Items
[
targetIndex
]
name
:=
strings
.
TrimSpace
(
profile
.
Name
)
if
name
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_NAME_REQUIRED"
,
"name is required"
)
}
target
.
Name
=
name
target
.
Enabled
=
profile
.
Enabled
target
.
Endpoint
=
strings
.
TrimSpace
(
profile
.
Endpoint
)
target
.
Region
=
strings
.
TrimSpace
(
profile
.
Region
)
target
.
Bucket
=
strings
.
TrimSpace
(
profile
.
Bucket
)
target
.
AccessKeyID
=
strings
.
TrimSpace
(
profile
.
AccessKeyID
)
target
.
Prefix
=
strings
.
TrimSpace
(
profile
.
Prefix
)
target
.
ForcePathStyle
=
profile
.
ForcePathStyle
target
.
CDNURL
=
strings
.
TrimSpace
(
profile
.
CDNURL
)
target
.
DefaultStorageQuotaBytes
=
maxInt64
(
profile
.
DefaultStorageQuotaBytes
,
0
)
if
profile
.
SecretAccessKey
!=
""
{
target
.
SecretAccessKey
=
profile
.
SecretAccessKey
}
target
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
store
.
Items
[
targetIndex
]
=
target
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
updated
:=
findSoraS3ProfileByID
(
profiles
.
Items
,
targetID
)
if
updated
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
updated
,
nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func
(
s
*
SettingService
)
DeleteSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
ErrSoraS3ProfileNotFound
}
store
.
Items
=
append
(
store
.
Items
[
:
targetIndex
],
store
.
Items
[
targetIndex
+
1
:
]
...
)
if
store
.
ActiveProfileID
==
targetID
{
store
.
ActiveProfileID
=
""
if
len
(
store
.
Items
)
>
0
{
store
.
ActiveProfileID
=
store
.
Items
[
0
]
.
ProfileID
}
}
return
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func
(
s
*
SettingService
)
SetActiveSoraS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
*
SoraS3Profile
,
error
)
{
targetID
:=
strings
.
TrimSpace
(
profileID
)
if
targetID
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"SORA_S3_PROFILE_ID_REQUIRED"
,
"profile_id is required"
)
}
store
,
err
:=
s
.
loadSoraS3ProfilesStore
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
targetIndex
:=
findSoraS3ProfileIndex
(
store
.
Items
,
targetID
)
if
targetIndex
<
0
{
return
nil
,
ErrSoraS3ProfileNotFound
}
store
.
ActiveProfileID
=
targetID
store
.
Items
[
targetIndex
]
.
UpdatedAt
=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
persistSoraS3ProfilesStore
(
ctx
,
store
);
err
!=
nil
{
return
nil
,
err
}
profiles
:=
convertSoraS3ProfilesStore
(
store
)
active
:=
pickActiveSoraS3Profile
(
profiles
.
Items
,
profiles
.
ActiveProfileID
)
if
active
==
nil
{
return
nil
,
ErrSoraS3ProfileNotFound
}
return
active
,
nil
}
func
(
s
*
SettingService
)
loadSoraS3ProfilesStore
(
ctx
context
.
Context
)
(
*
soraS3ProfilesStore
,
error
)
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySoraS3Profiles
)
if
err
==
nil
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
&
soraS3ProfilesStore
{},
nil
}
var
store
soraS3ProfilesStore
if
unmarshalErr
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
store
);
unmarshalErr
!=
nil
{
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal sora s3 profiles: %w"
,
unmarshalErr
)
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
normalized
:=
normalizeSoraS3ProfilesStore
(
store
)
return
&
normalized
,
nil
}
if
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"get sora s3 profiles: %w"
,
err
)
}
legacy
,
legacyErr
:=
s
.
getLegacySoraS3Settings
(
ctx
)
if
legacyErr
!=
nil
{
return
nil
,
legacyErr
}
if
isEmptyLegacySoraS3Settings
(
legacy
)
{
return
&
soraS3ProfilesStore
{},
nil
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
return
&
soraS3ProfilesStore
{
ActiveProfileID
:
"default"
,
Items
:
[]
soraS3ProfileStoreItem
{
{
ProfileID
:
"default"
,
Name
:
"Default"
,
Enabled
:
legacy
.
Enabled
,
Endpoint
:
strings
.
TrimSpace
(
legacy
.
Endpoint
),
Region
:
strings
.
TrimSpace
(
legacy
.
Region
),
Bucket
:
strings
.
TrimSpace
(
legacy
.
Bucket
),
AccessKeyID
:
strings
.
TrimSpace
(
legacy
.
AccessKeyID
),
SecretAccessKey
:
legacy
.
SecretAccessKey
,
Prefix
:
strings
.
TrimSpace
(
legacy
.
Prefix
),
ForcePathStyle
:
legacy
.
ForcePathStyle
,
CDNURL
:
strings
.
TrimSpace
(
legacy
.
CDNURL
),
DefaultStorageQuotaBytes
:
maxInt64
(
legacy
.
DefaultStorageQuotaBytes
,
0
),
UpdatedAt
:
now
,
},
},
},
nil
}
func
(
s
*
SettingService
)
persistSoraS3ProfilesStore
(
ctx
context
.
Context
,
store
*
soraS3ProfilesStore
)
error
{
if
store
==
nil
{
return
fmt
.
Errorf
(
"sora s3 profiles store cannot be nil"
)
}
normalized
:=
normalizeSoraS3ProfilesStore
(
*
store
)
data
,
err
:=
json
.
Marshal
(
normalized
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal sora s3 profiles: %w"
,
err
)
}
updates
:=
map
[
string
]
string
{
SettingKeySoraS3Profiles
:
string
(
data
),
}
active
:=
pickActiveSoraS3ProfileFromStore
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
if
active
==
nil
{
updates
[
SettingKeySoraS3Enabled
]
=
"false"
updates
[
SettingKeySoraS3Endpoint
]
=
""
updates
[
SettingKeySoraS3Region
]
=
""
updates
[
SettingKeySoraS3Bucket
]
=
""
updates
[
SettingKeySoraS3AccessKeyID
]
=
""
updates
[
SettingKeySoraS3Prefix
]
=
""
updates
[
SettingKeySoraS3ForcePathStyle
]
=
"false"
updates
[
SettingKeySoraS3CDNURL
]
=
""
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
"0"
updates
[
SettingKeySoraS3SecretAccessKey
]
=
""
}
else
{
updates
[
SettingKeySoraS3Enabled
]
=
strconv
.
FormatBool
(
active
.
Enabled
)
updates
[
SettingKeySoraS3Endpoint
]
=
strings
.
TrimSpace
(
active
.
Endpoint
)
updates
[
SettingKeySoraS3Region
]
=
strings
.
TrimSpace
(
active
.
Region
)
updates
[
SettingKeySoraS3Bucket
]
=
strings
.
TrimSpace
(
active
.
Bucket
)
updates
[
SettingKeySoraS3AccessKeyID
]
=
strings
.
TrimSpace
(
active
.
AccessKeyID
)
updates
[
SettingKeySoraS3Prefix
]
=
strings
.
TrimSpace
(
active
.
Prefix
)
updates
[
SettingKeySoraS3ForcePathStyle
]
=
strconv
.
FormatBool
(
active
.
ForcePathStyle
)
updates
[
SettingKeySoraS3CDNURL
]
=
strings
.
TrimSpace
(
active
.
CDNURL
)
updates
[
SettingKeySoraDefaultStorageQuotaBytes
]
=
strconv
.
FormatInt
(
maxInt64
(
active
.
DefaultStorageQuotaBytes
,
0
),
10
)
updates
[
SettingKeySoraS3SecretAccessKey
]
=
active
.
SecretAccessKey
}
if
err
:=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
);
err
!=
nil
{
return
err
}
if
s
.
onUpdate
!=
nil
{
s
.
onUpdate
()
}
if
s
.
onS3Update
!=
nil
{
s
.
onS3Update
()
}
return
nil
}
func
(
s
*
SettingService
)
getLegacySoraS3Settings
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
keys
:=
[]
string
{
SettingKeySoraS3Enabled
,
SettingKeySoraS3Endpoint
,
SettingKeySoraS3Region
,
SettingKeySoraS3Bucket
,
SettingKeySoraS3AccessKeyID
,
SettingKeySoraS3SecretAccessKey
,
SettingKeySoraS3Prefix
,
SettingKeySoraS3ForcePathStyle
,
SettingKeySoraS3CDNURL
,
SettingKeySoraDefaultStorageQuotaBytes
,
}
values
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get legacy sora s3 settings: %w"
,
err
)
}
result
:=
&
SoraS3Settings
{
Enabled
:
values
[
SettingKeySoraS3Enabled
]
==
"true"
,
Endpoint
:
values
[
SettingKeySoraS3Endpoint
],
Region
:
values
[
SettingKeySoraS3Region
],
Bucket
:
values
[
SettingKeySoraS3Bucket
],
AccessKeyID
:
values
[
SettingKeySoraS3AccessKeyID
],
SecretAccessKey
:
values
[
SettingKeySoraS3SecretAccessKey
],
SecretAccessKeyConfigured
:
values
[
SettingKeySoraS3SecretAccessKey
]
!=
""
,
Prefix
:
values
[
SettingKeySoraS3Prefix
],
ForcePathStyle
:
values
[
SettingKeySoraS3ForcePathStyle
]
==
"true"
,
CDNURL
:
values
[
SettingKeySoraS3CDNURL
],
}
if
v
,
parseErr
:=
strconv
.
ParseInt
(
values
[
SettingKeySoraDefaultStorageQuotaBytes
],
10
,
64
);
parseErr
==
nil
{
result
.
DefaultStorageQuotaBytes
=
v
}
return
result
,
nil
}
func
normalizeSoraS3ProfilesStore
(
store
soraS3ProfilesStore
)
soraS3ProfilesStore
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
store
.
Items
))
normalized
:=
soraS3ProfilesStore
{
ActiveProfileID
:
strings
.
TrimSpace
(
store
.
ActiveProfileID
),
Items
:
make
([]
soraS3ProfileStoreItem
,
0
,
len
(
store
.
Items
)),
}
now
:=
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
)
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
item
.
ProfileID
=
strings
.
TrimSpace
(
item
.
ProfileID
)
if
item
.
ProfileID
==
""
{
item
.
ProfileID
=
fmt
.
Sprintf
(
"profile-%d"
,
idx
+
1
)
}
if
_
,
exists
:=
seen
[
item
.
ProfileID
];
exists
{
continue
}
seen
[
item
.
ProfileID
]
=
struct
{}{}
item
.
Name
=
strings
.
TrimSpace
(
item
.
Name
)
if
item
.
Name
==
""
{
item
.
Name
=
item
.
ProfileID
}
item
.
Endpoint
=
strings
.
TrimSpace
(
item
.
Endpoint
)
item
.
Region
=
strings
.
TrimSpace
(
item
.
Region
)
item
.
Bucket
=
strings
.
TrimSpace
(
item
.
Bucket
)
item
.
AccessKeyID
=
strings
.
TrimSpace
(
item
.
AccessKeyID
)
item
.
Prefix
=
strings
.
TrimSpace
(
item
.
Prefix
)
item
.
CDNURL
=
strings
.
TrimSpace
(
item
.
CDNURL
)
item
.
DefaultStorageQuotaBytes
=
maxInt64
(
item
.
DefaultStorageQuotaBytes
,
0
)
item
.
UpdatedAt
=
strings
.
TrimSpace
(
item
.
UpdatedAt
)
if
item
.
UpdatedAt
==
""
{
item
.
UpdatedAt
=
now
}
normalized
.
Items
=
append
(
normalized
.
Items
,
item
)
}
if
len
(
normalized
.
Items
)
==
0
{
normalized
.
ActiveProfileID
=
""
return
normalized
}
if
findSoraS3ProfileIndex
(
normalized
.
Items
,
normalized
.
ActiveProfileID
)
>=
0
{
return
normalized
}
normalized
.
ActiveProfileID
=
normalized
.
Items
[
0
]
.
ProfileID
return
normalized
}
func
convertSoraS3ProfilesStore
(
store
*
soraS3ProfilesStore
)
*
SoraS3ProfileList
{
if
store
==
nil
{
return
&
SoraS3ProfileList
{}
}
items
:=
make
([]
SoraS3Profile
,
0
,
len
(
store
.
Items
))
for
idx
:=
range
store
.
Items
{
item
:=
store
.
Items
[
idx
]
items
=
append
(
items
,
SoraS3Profile
{
ProfileID
:
item
.
ProfileID
,
Name
:
item
.
Name
,
IsActive
:
item
.
ProfileID
==
store
.
ActiveProfileID
,
Enabled
:
item
.
Enabled
,
Endpoint
:
item
.
Endpoint
,
Region
:
item
.
Region
,
Bucket
:
item
.
Bucket
,
AccessKeyID
:
item
.
AccessKeyID
,
SecretAccessKey
:
item
.
SecretAccessKey
,
SecretAccessKeyConfigured
:
item
.
SecretAccessKey
!=
""
,
Prefix
:
item
.
Prefix
,
ForcePathStyle
:
item
.
ForcePathStyle
,
CDNURL
:
item
.
CDNURL
,
DefaultStorageQuotaBytes
:
item
.
DefaultStorageQuotaBytes
,
UpdatedAt
:
item
.
UpdatedAt
,
})
}
return
&
SoraS3ProfileList
{
ActiveProfileID
:
store
.
ActiveProfileID
,
Items
:
items
,
}
}
func
pickActiveSoraS3Profile
(
items
[]
SoraS3Profile
,
activeProfileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileByID
(
items
[]
SoraS3Profile
,
profileID
string
)
*
SoraS3Profile
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
&
items
[
idx
]
}
}
return
nil
}
func
pickActiveSoraS3ProfileFromStore
(
items
[]
soraS3ProfileStoreItem
,
activeProfileID
string
)
*
soraS3ProfileStoreItem
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
activeProfileID
{
return
&
items
[
idx
]
}
}
if
len
(
items
)
==
0
{
return
nil
}
return
&
items
[
0
]
}
func
findSoraS3ProfileIndex
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
int
{
for
idx
:=
range
items
{
if
items
[
idx
]
.
ProfileID
==
profileID
{
return
idx
}
}
return
-
1
}
func
hasSoraS3ProfileID
(
items
[]
soraS3ProfileStoreItem
,
profileID
string
)
bool
{
return
findSoraS3ProfileIndex
(
items
,
profileID
)
>=
0
}
func
isEmptyLegacySoraS3Settings
(
settings
*
SoraS3Settings
)
bool
{
if
settings
==
nil
{
return
true
}
if
settings
.
Enabled
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Endpoint
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Region
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Bucket
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
AccessKeyID
)
!=
""
{
return
false
}
if
settings
.
SecretAccessKey
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
Prefix
)
!=
""
{
return
false
}
if
strings
.
TrimSpace
(
settings
.
CDNURL
)
!=
""
{
return
false
}
return
settings
.
DefaultStorageQuotaBytes
==
0
}
func
maxInt64
(
value
int64
,
min
int64
)
int64
{
if
value
<
min
{
return
min
}
return
value
}
backend/internal/service/settings_view.go
View file @
d757df8a
...
...
@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton
bool
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled
bool
PurchaseSubscriptionURL
string
SoraClientEnabled
bool
CustomMenuItems
string
// JSON array of custom menu items
CustomEndpoints
string
// JSON array of custom endpoints
...
...
@@ -116,46 +114,6 @@ type PublicSettings struct {
Version
string
}
// SoraS3Settings Sora S3 存储配置
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"-"`
// 仅内部使用,不直接返回前端
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
// 前端展示用
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type
SoraS3ProfileList
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type
StreamTimeoutSettings
struct
{
// Enabled 是否启用流超时处理
...
...
backend/internal/service/sora_account_service.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
"context"
// SoraAccountRepository Sora 账号扩展表仓储接口
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
//
// 设计说明:
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
// - Token 刷新时需要同时更新两个表以保持数据一致性
type
SoraAccountRepository
interface
{
// Upsert 创建或更新 Sora 账号扩展信息
// accountID: 关联的 accounts.id
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
//
// 如果记录不存在则创建,存在则更新。
// 用于:
// 1. 创建 Sora 账号时初始化扩展表
// 2. Token 刷新时同步更新扩展表
Upsert
(
ctx
context
.
Context
,
accountID
int64
,
updates
map
[
string
]
any
)
error
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
// 返回 nil, nil 表示记录不存在(非错误)
GetByAccountID
(
ctx
context
.
Context
,
accountID
int64
)
(
*
SoraAccount
,
error
)
// Delete 删除 Sora 账号扩展信息
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
Delete
(
ctx
context
.
Context
,
accountID
int64
)
error
}
// SoraAccount Sora 账号扩展信息
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
type
SoraAccount
struct
{
AccountID
int64
// 关联的 accounts.id
AccessToken
string
// OAuth access_token
RefreshToken
string
// OAuth refresh_token
SessionToken
string
// Session token(可选,用于 ST→AT 兜底)
}
backend/internal/service/sora_client.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"fmt"
"net/http"
)
// SoraClient 定义直连 Sora 的任务操作接口。
type
SoraClient
interface
{
Enabled
()
bool
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
}
// SoraImageRequest 图片生成请求参数
type
SoraImageRequest
struct
{
Prompt
string
Width
int
Height
int
MediaID
string
}
// SoraVideoRequest 视频生成请求参数
type
SoraVideoRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
VideoCount
int
MediaID
string
RemixTargetID
string
CameoIDs
[]
string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type
SoraStoryboardRequest
struct
{
Prompt
string
Orientation
string
Frames
int
Model
string
Size
string
MediaID
string
}
// SoraImageTaskStatus 图片任务状态
type
SoraImageTaskStatus
struct
{
ID
string
Status
string
ProgressPct
float64
URLs
[]
string
ErrorMsg
string
}
// SoraVideoTaskStatus 视频任务状态
type
SoraVideoTaskStatus
struct
{
ID
string
Status
string
ProgressPct
int
URLs
[]
string
GenerationID
string
ErrorMsg
string
}
// SoraCameoStatus 角色处理中间态
type
SoraCameoStatus
struct
{
Status
string
StatusMessage
string
DisplayNameHint
string
UsernameHint
string
ProfileAssetURL
string
InstructionSetHint
any
InstructionSet
any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type
SoraCharacterFinalizeRequest
struct
{
CameoID
string
Username
string
DisplayName
string
ProfileAssetPointer
string
InstructionSet
any
}
// SoraUpstreamError 上游错误
type
SoraUpstreamError
struct
{
StatusCode
int
Message
string
Headers
http
.
Header
Body
[]
byte
}
func
(
e
*
SoraUpstreamError
)
Error
()
string
{
if
e
==
nil
{
return
"sora upstream error"
}
if
e
.
Message
!=
""
{
return
fmt
.
Sprintf
(
"sora upstream error: %d %s"
,
e
.
StatusCode
,
e
.
Message
)
}
return
fmt
.
Sprintf
(
"sora upstream error: %d"
,
e
.
StatusCode
)
}
backend/internal/service/sora_gateway_service.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"math/rand"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
const
soraImageInputMaxBytes
=
20
<<
20
const
soraImageInputMaxRedirects
=
3
const
soraImageInputTimeout
=
20
*
time
.
Second
const
soraVideoInputMaxBytes
=
200
<<
20
const
soraVideoInputMaxRedirects
=
3
const
soraVideoInputTimeout
=
60
*
time
.
Second
var
soraImageSizeMap
=
map
[
string
]
string
{
"gpt-image"
:
"360"
,
"gpt-image-landscape"
:
"540"
,
"gpt-image-portrait"
:
"540"
,
}
var
soraBlockedHostnames
=
map
[
string
]
struct
{}{
"localhost"
:
{},
"localhost.localdomain"
:
{},
"metadata.google.internal"
:
{},
"metadata.google.internal."
:
{},
}
var
soraBlockedCIDRs
=
mustParseCIDRs
([]
string
{
"0.0.0.0/8"
,
"10.0.0.0/8"
,
"100.64.0.0/10"
,
"127.0.0.0/8"
,
"169.254.0.0/16"
,
"172.16.0.0/12"
,
"192.168.0.0/16"
,
"224.0.0.0/4"
,
"240.0.0.0/4"
,
"::/128"
,
"::1/128"
,
"fc00::/7"
,
"fe80::/10"
,
})
// SoraGatewayService handles forwarding requests to Sora upstream.
type
SoraGatewayService
struct
{
soraClient
SoraClient
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
// 用于 apikey 类型账号的 HTTP 透传
cfg
*
config
.
Config
}
type
soraWatermarkOptions
struct
{
Enabled
bool
ParseMethod
string
ParseURL
string
ParseToken
string
FallbackOnFailure
bool
DeletePost
bool
}
type
soraCharacterOptions
struct
{
SetPublic
bool
DeleteAfterGenerate
bool
}
type
soraCharacterFlowResult
struct
{
CameoID
string
CharacterID
string
Username
string
DisplayName
string
}
var
soraStoryboardPattern
=
regexp
.
MustCompile
(
`\[\d+(?:\.\d+)?s\]`
)
var
soraStoryboardShotPattern
=
regexp
.
MustCompile
(
`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`
)
var
soraRemixTargetPattern
=
regexp
.
MustCompile
(
`s_[a-f0-9]{32}`
)
var
soraRemixTargetInURLPattern
=
regexp
.
MustCompile
(
`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`
)
type
soraPreflightChecker
interface
{
PreflightCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
,
modelCfg
SoraModelConfig
)
error
}
func
NewSoraGatewayService
(
soraClient
SoraClient
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
)
*
SoraGatewayService
{
return
&
SoraGatewayService
{
soraClient
:
soraClient
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
clientStream
bool
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
if
account
.
Type
==
AccountTypeAPIKey
&&
account
.
GetBaseURL
()
!=
""
{
if
s
.
httpUpstream
==
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"HTTP upstream client not configured"
,
clientStream
)
return
nil
,
errors
.
New
(
"httpUpstream not configured for sora apikey forwarding"
)
}
return
s
.
forwardToUpstream
(
ctx
,
c
,
account
,
body
,
clientStream
,
startTime
)
}
if
s
.
soraClient
==
nil
||
!
s
.
soraClient
.
Enabled
()
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"api_error"
,
"message"
:
"Sora 上游未配置"
,
},
})
}
return
nil
,
errors
.
New
(
"sora upstream not configured"
)
}
var
reqBody
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
reqBody
);
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"parse request: %w"
,
err
)
}
reqModel
,
_
:=
reqBody
[
"model"
]
.
(
string
)
reqStream
,
_
:=
reqBody
[
"stream"
]
.
(
bool
)
if
strings
.
TrimSpace
(
reqModel
)
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"model is required"
)
}
originalModel
:=
reqModel
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
var
upstreamModel
string
if
mappedModel
!=
""
&&
mappedModel
!=
reqModel
{
reqModel
=
mappedModel
upstreamModel
=
mappedModel
}
modelCfg
,
ok
:=
GetSoraModelConfig
(
reqModel
)
if
!
ok
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Unsupported Sora model"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"unsupported model: %s"
,
reqModel
)
}
prompt
,
imageInput
,
videoInput
,
remixTargetID
:=
extractSoraInput
(
reqBody
)
prompt
=
strings
.
TrimSpace
(
prompt
)
imageInput
=
strings
.
TrimSpace
(
imageInput
)
videoInput
=
strings
.
TrimSpace
(
videoInput
)
remixTargetID
=
strings
.
TrimSpace
(
remixTargetID
)
if
videoInput
!=
""
&&
modelCfg
.
Type
!=
"video"
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"video input only supports video models"
,
clientStream
)
return
nil
,
errors
.
New
(
"video input only supports video models"
)
}
if
videoInput
!=
""
&&
imageInput
!=
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"image input and video input cannot be used together"
,
clientStream
)
return
nil
,
errors
.
New
(
"image input and video input cannot be used together"
)
}
characterOnly
:=
videoInput
!=
""
&&
prompt
==
""
if
modelCfg
.
Type
==
"prompt_enhance"
&&
prompt
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
if
modelCfg
.
Type
!=
"prompt_enhance"
&&
prompt
==
""
&&
!
characterOnly
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"prompt is required"
,
clientStream
)
return
nil
,
errors
.
New
(
"prompt is required"
)
}
reqCtx
,
cancel
:=
s
.
withSoraTimeout
(
ctx
,
reqStream
)
if
cancel
!=
nil
{
defer
cancel
()
}
if
checker
,
ok
:=
s
.
soraClient
.
(
soraPreflightChecker
);
ok
&&
!
characterOnly
{
if
err
:=
checker
.
PreflightCheck
(
reqCtx
,
account
,
reqModel
,
modelCfg
);
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
}
if
modelCfg
.
Type
==
"prompt_enhance"
{
enhancedPrompt
,
err
:=
s
.
soraClient
.
EnhancePrompt
(
reqCtx
,
account
,
prompt
,
modelCfg
.
ExpansionLevel
,
modelCfg
.
DurationS
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
content
:=
strings
.
TrimSpace
(
enhancedPrompt
)
if
content
==
""
{
content
=
prompt
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusOK
,
buildSoraNonStreamResponse
(
content
,
reqModel
))
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
characterOpts
:=
parseSoraCharacterOptions
(
reqBody
)
watermarkOpts
:=
parseSoraWatermarkOptions
(
reqBody
)
var
characterResult
*
soraCharacterFlowResult
if
videoInput
!=
""
{
videoData
,
videoErr
:=
decodeSoraVideoInput
(
reqCtx
,
videoInput
)
if
videoErr
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
videoErr
.
Error
(),
clientStream
)
return
nil
,
videoErr
}
characterResult
,
videoErr
=
s
.
createCharacterFromVideo
(
reqCtx
,
account
,
videoData
,
characterOpts
)
if
videoErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
videoErr
,
reqModel
,
c
,
clientStream
)
}
if
characterResult
!=
nil
&&
characterOpts
.
DeleteAfterGenerate
&&
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
!=
""
&&
!
characterOnly
{
characterID
:=
strings
.
TrimSpace
(
characterResult
.
CharacterID
)
defer
func
()
{
cleanupCtx
,
cancelCleanup
:=
context
.
WithTimeout
(
context
.
Background
(),
15
*
time
.
Second
)
defer
cancelCleanup
()
if
err
:=
s
.
soraClient
.
DeleteCharacter
(
cleanupCtx
,
account
,
characterID
);
err
!=
nil
{
log
.
Printf
(
"[Sora] cleanup character failed, character_id=%s err=%v"
,
characterID
,
err
)
}
}()
}
if
characterOnly
{
content
:=
"角色创建成功"
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
content
=
fmt
.
Sprintf
(
"角色创建成功,角色名@%s"
,
strings
.
TrimSpace
(
characterResult
.
Username
))
}
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
resp
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
characterResult
!=
nil
{
resp
[
"character_id"
]
=
characterResult
.
CharacterID
resp
[
"cameo_id"
]
=
characterResult
.
CameoID
resp
[
"character_username"
]
=
characterResult
.
Username
resp
[
"character_display_name"
]
=
characterResult
.
DisplayName
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
return
&
ForwardResult
{
RequestID
:
""
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
"prompt"
,
},
nil
}
if
characterResult
!=
nil
&&
strings
.
TrimSpace
(
characterResult
.
Username
)
!=
""
{
prompt
=
fmt
.
Sprintf
(
"@%s %s"
,
characterResult
.
Username
,
prompt
)
}
}
var
imageData
[]
byte
imageFilename
:=
""
if
imageInput
!=
""
{
decoded
,
filename
,
err
:=
decodeSoraImageInput
(
reqCtx
,
imageInput
)
if
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
(),
clientStream
)
return
nil
,
err
}
imageData
=
decoded
imageFilename
=
filename
}
mediaID
:=
""
if
len
(
imageData
)
>
0
{
uploadID
,
err
:=
s
.
soraClient
.
UploadImage
(
reqCtx
,
account
,
imageData
,
imageFilename
)
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
mediaID
=
uploadID
}
taskID
:=
""
var
err
error
videoCount
:=
parseSoraVideoCount
(
reqBody
)
switch
modelCfg
.
Type
{
case
"image"
:
taskID
,
err
=
s
.
soraClient
.
CreateImageTask
(
reqCtx
,
account
,
SoraImageRequest
{
Prompt
:
prompt
,
Width
:
modelCfg
.
Width
,
Height
:
modelCfg
.
Height
,
MediaID
:
mediaID
,
})
case
"video"
:
if
remixTargetID
==
""
&&
isSoraStoryboardPrompt
(
prompt
)
{
taskID
,
err
=
s
.
soraClient
.
CreateStoryboardTask
(
reqCtx
,
account
,
SoraStoryboardRequest
{
Prompt
:
formatSoraStoryboardPrompt
(
prompt
),
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
MediaID
:
mediaID
,
})
}
else
{
taskID
,
err
=
s
.
soraClient
.
CreateVideoTask
(
reqCtx
,
account
,
SoraVideoRequest
{
Prompt
:
prompt
,
Orientation
:
modelCfg
.
Orientation
,
Frames
:
modelCfg
.
Frames
,
Model
:
modelCfg
.
Model
,
Size
:
modelCfg
.
Size
,
VideoCount
:
videoCount
,
MediaID
:
mediaID
,
RemixTargetID
:
remixTargetID
,
CameoIDs
:
extractSoraCameoIDs
(
reqBody
),
})
}
default
:
err
=
fmt
.
Errorf
(
"unsupported model type: %s"
,
modelCfg
.
Type
)
}
if
err
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
err
,
reqModel
,
c
,
clientStream
)
}
if
clientStream
&&
c
!=
nil
{
s
.
prepareSoraStream
(
c
,
taskID
)
}
var
mediaURLs
[]
string
videoGenerationID
:=
""
mediaType
:=
modelCfg
.
Type
imageCount
:=
0
imageSize
:=
""
switch
modelCfg
.
Type
{
case
"image"
:
urls
,
pollErr
:=
s
.
pollImageTask
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
mediaURLs
=
urls
imageCount
=
len
(
urls
)
imageSize
=
soraImageSizeFromModel
(
reqModel
)
case
"video"
:
videoStatus
,
pollErr
:=
s
.
pollVideoTaskDetailed
(
reqCtx
,
c
,
account
,
taskID
,
clientStream
)
if
pollErr
!=
nil
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
pollErr
,
reqModel
,
c
,
clientStream
)
}
if
videoStatus
!=
nil
{
mediaURLs
=
videoStatus
.
URLs
videoGenerationID
=
strings
.
TrimSpace
(
videoStatus
.
GenerationID
)
}
default
:
mediaType
=
"prompt"
}
watermarkPostID
:=
""
if
modelCfg
.
Type
==
"video"
&&
watermarkOpts
.
Enabled
{
watermarkURL
,
postID
,
watermarkErr
:=
s
.
resolveWatermarkFreeURL
(
reqCtx
,
account
,
videoGenerationID
,
watermarkOpts
)
if
watermarkErr
!=
nil
{
if
!
watermarkOpts
.
FallbackOnFailure
{
return
nil
,
s
.
handleSoraRequestError
(
ctx
,
account
,
watermarkErr
,
reqModel
,
c
,
clientStream
)
}
log
.
Printf
(
"[Sora] watermark-free fallback to original URL, task_id=%s err=%v"
,
taskID
,
watermarkErr
)
}
else
if
strings
.
TrimSpace
(
watermarkURL
)
!=
""
{
mediaURLs
=
[]
string
{
strings
.
TrimSpace
(
watermarkURL
)}
watermarkPostID
=
strings
.
TrimSpace
(
postID
)
}
}
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs
:=
s
.
normalizeSoraMediaURLs
(
mediaURLs
)
if
watermarkPostID
!=
""
&&
watermarkOpts
.
DeletePost
{
if
deleteErr
:=
s
.
soraClient
.
DeletePost
(
reqCtx
,
account
,
watermarkPostID
);
deleteErr
!=
nil
{
log
.
Printf
(
"[Sora] delete post failed, post_id=%s err=%v"
,
watermarkPostID
,
deleteErr
)
}
}
content
:=
buildSoraContent
(
mediaType
,
finalURLs
)
var
firstTokenMs
*
int
if
clientStream
{
ms
,
streamErr
:=
s
.
writeSoraStream
(
c
,
reqModel
,
content
,
startTime
)
if
streamErr
!=
nil
{
return
nil
,
streamErr
}
firstTokenMs
=
ms
}
else
if
c
!=
nil
{
response
:=
buildSoraNonStreamResponse
(
content
,
reqModel
)
if
len
(
finalURLs
)
>
0
{
response
[
"media_url"
]
=
finalURLs
[
0
]
if
len
(
finalURLs
)
>
1
{
response
[
"media_urls"
]
=
finalURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
ForwardResult
{
RequestID
:
taskID
,
Model
:
originalModel
,
UpstreamModel
:
upstreamModel
,
Stream
:
clientStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{},
MediaType
:
mediaType
,
MediaURL
:
firstMediaURL
(
finalURLs
),
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
func
(
s
*
SoraGatewayService
)
withSoraTimeout
(
ctx
context
.
Context
,
stream
bool
)
(
context
.
Context
,
context
.
CancelFunc
)
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
ctx
,
nil
}
timeoutSeconds
:=
s
.
cfg
.
Gateway
.
SoraRequestTimeoutSeconds
if
stream
{
timeoutSeconds
=
s
.
cfg
.
Gateway
.
SoraStreamTimeoutSeconds
}
if
timeoutSeconds
<=
0
{
return
ctx
,
nil
}
return
context
.
WithTimeout
(
ctx
,
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
)
}
func
parseSoraWatermarkOptions
(
body
map
[
string
]
any
)
soraWatermarkOptions
{
opts
:=
soraWatermarkOptions
{
Enabled
:
parseBoolWithDefault
(
body
,
"watermark_free"
,
false
),
ParseMethod
:
strings
.
ToLower
(
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_method"
,
"third_party"
))),
ParseURL
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_url"
,
""
)),
ParseToken
:
strings
.
TrimSpace
(
parseStringWithDefault
(
body
,
"watermark_parse_token"
,
""
)),
FallbackOnFailure
:
parseBoolWithDefault
(
body
,
"watermark_fallback_on_failure"
,
true
),
DeletePost
:
parseBoolWithDefault
(
body
,
"watermark_delete_post"
,
false
),
}
if
opts
.
ParseMethod
==
""
{
opts
.
ParseMethod
=
"third_party"
}
return
opts
}
func
parseSoraCharacterOptions
(
body
map
[
string
]
any
)
soraCharacterOptions
{
return
soraCharacterOptions
{
SetPublic
:
parseBoolWithDefault
(
body
,
"character_set_public"
,
true
),
DeleteAfterGenerate
:
parseBoolWithDefault
(
body
,
"character_delete_after_generate"
,
true
),
}
}
func
parseSoraVideoCount
(
body
map
[
string
]
any
)
int
{
if
body
==
nil
{
return
1
}
keys
:=
[]
string
{
"video_count"
,
"videos"
,
"n_variants"
}
for
_
,
key
:=
range
keys
{
count
:=
parseIntWithDefault
(
body
,
key
,
0
)
if
count
>
0
{
return
clampInt
(
count
,
1
,
3
)
}
}
return
1
}
func
parseBoolWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
bool
)
bool
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
bool
:
return
typed
case
int
:
return
typed
!=
0
case
int32
:
return
typed
!=
0
case
int64
:
return
typed
!=
0
case
float64
:
return
typed
!=
0
case
string
:
typed
=
strings
.
ToLower
(
strings
.
TrimSpace
(
typed
))
if
typed
==
"true"
||
typed
==
"1"
||
typed
==
"yes"
{
return
true
}
if
typed
==
"false"
||
typed
==
"0"
||
typed
==
"no"
{
return
false
}
}
return
def
}
func
parseStringWithDefault
(
body
map
[
string
]
any
,
key
,
def
string
)
string
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
if
str
,
ok
:=
val
.
(
string
);
ok
{
return
str
}
return
def
}
func
parseIntWithDefault
(
body
map
[
string
]
any
,
key
string
,
def
int
)
int
{
if
body
==
nil
{
return
def
}
val
,
ok
:=
body
[
key
]
if
!
ok
{
return
def
}
switch
typed
:=
val
.
(
type
)
{
case
int
:
return
typed
case
int32
:
return
int
(
typed
)
case
int64
:
return
int
(
typed
)
case
float64
:
return
int
(
typed
)
case
string
:
parsed
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
typed
))
if
err
==
nil
{
return
parsed
}
}
return
def
}
func
clampInt
(
v
,
minVal
,
maxVal
int
)
int
{
if
v
<
minVal
{
return
minVal
}
if
v
>
maxVal
{
return
maxVal
}
return
v
}
func
extractSoraCameoIDs
(
body
map
[
string
]
any
)
[]
string
{
if
body
==
nil
{
return
nil
}
raw
,
ok
:=
body
[
"cameo_ids"
]
if
!
ok
{
return
nil
}
switch
typed
:=
raw
.
(
type
)
{
case
[]
string
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
item
=
strings
.
TrimSpace
(
item
)
if
item
!=
""
{
out
=
append
(
out
,
item
)
}
}
return
out
case
[]
any
:
out
:=
make
([]
string
,
0
,
len
(
typed
))
for
_
,
item
:=
range
typed
{
str
,
ok
:=
item
.
(
string
)
if
!
ok
{
continue
}
str
=
strings
.
TrimSpace
(
str
)
if
str
!=
""
{
out
=
append
(
out
,
str
)
}
}
return
out
default
:
return
nil
}
}
func
(
s
*
SoraGatewayService
)
createCharacterFromVideo
(
ctx
context
.
Context
,
account
*
Account
,
videoData
[]
byte
,
opts
soraCharacterOptions
)
(
*
soraCharacterFlowResult
,
error
)
{
cameoID
,
err
:=
s
.
soraClient
.
UploadCharacterVideo
(
ctx
,
account
,
videoData
)
if
err
!=
nil
{
return
nil
,
err
}
cameoStatus
,
err
:=
s
.
pollCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
return
nil
,
err
}
username
:=
processSoraCharacterUsername
(
cameoStatus
.
UsernameHint
)
displayName
:=
strings
.
TrimSpace
(
cameoStatus
.
DisplayNameHint
)
if
displayName
==
""
{
displayName
=
"Character"
}
profileAssetURL
:=
strings
.
TrimSpace
(
cameoStatus
.
ProfileAssetURL
)
if
profileAssetURL
==
""
{
return
nil
,
errors
.
New
(
"profile asset url not found in cameo status"
)
}
avatarData
,
err
:=
s
.
soraClient
.
DownloadCharacterImage
(
ctx
,
account
,
profileAssetURL
)
if
err
!=
nil
{
return
nil
,
err
}
assetPointer
,
err
:=
s
.
soraClient
.
UploadCharacterImage
(
ctx
,
account
,
avatarData
)
if
err
!=
nil
{
return
nil
,
err
}
instructionSet
:=
cameoStatus
.
InstructionSetHint
if
instructionSet
==
nil
{
instructionSet
=
cameoStatus
.
InstructionSet
}
characterID
,
err
:=
s
.
soraClient
.
FinalizeCharacter
(
ctx
,
account
,
SoraCharacterFinalizeRequest
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
Username
:
username
,
DisplayName
:
displayName
,
ProfileAssetPointer
:
assetPointer
,
InstructionSet
:
instructionSet
,
})
if
err
!=
nil
{
return
nil
,
err
}
if
opts
.
SetPublic
{
if
err
:=
s
.
soraClient
.
SetCharacterPublic
(
ctx
,
account
,
cameoID
);
err
!=
nil
{
return
nil
,
err
}
}
return
&
soraCharacterFlowResult
{
CameoID
:
strings
.
TrimSpace
(
cameoID
),
CharacterID
:
strings
.
TrimSpace
(
characterID
),
Username
:
strings
.
TrimSpace
(
username
),
DisplayName
:
displayName
,
},
nil
}
func
(
s
*
SoraGatewayService
)
pollCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
timeout
:=
10
*
time
.
Minute
interval
:=
5
*
time
.
Second
maxAttempts
:=
int
(
math
.
Ceil
(
timeout
.
Seconds
()
/
interval
.
Seconds
()))
if
maxAttempts
<
1
{
maxAttempts
=
1
}
var
lastErr
error
consecutiveErrors
:=
0
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetCameoStatus
(
ctx
,
account
,
cameoID
)
if
err
!=
nil
{
lastErr
=
err
consecutiveErrors
++
if
consecutiveErrors
>=
3
{
break
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
consecutiveErrors
=
0
if
status
==
nil
{
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
continue
}
currentStatus
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
status
.
Status
))
statusMessage
:=
strings
.
TrimSpace
(
status
.
StatusMessage
)
if
currentStatus
==
"failed"
{
if
statusMessage
==
""
{
statusMessage
=
"character creation failed"
}
return
nil
,
errors
.
New
(
statusMessage
)
}
if
strings
.
EqualFold
(
statusMessage
,
"Completed"
)
||
currentStatus
==
"finalized"
{
return
status
,
nil
}
if
attempt
<
maxAttempts
-
1
{
if
sleepErr
:=
sleepWithContext
(
ctx
,
interval
);
sleepErr
!=
nil
{
return
nil
,
sleepErr
}
}
}
if
lastErr
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"poll cameo status failed: %w"
,
lastErr
)
}
return
nil
,
errors
.
New
(
"cameo processing timeout"
)
}
func
processSoraCharacterUsername
(
usernameHint
string
)
string
{
usernameHint
=
strings
.
TrimSpace
(
usernameHint
)
if
usernameHint
==
""
{
usernameHint
=
"character"
}
if
strings
.
Contains
(
usernameHint
,
"."
)
{
parts
:=
strings
.
Split
(
usernameHint
,
"."
)
usernameHint
=
strings
.
TrimSpace
(
parts
[
len
(
parts
)
-
1
])
}
if
usernameHint
==
""
{
usernameHint
=
"character"
}
return
fmt
.
Sprintf
(
"%s%d"
,
usernameHint
,
rand
.
Intn
(
900
)
+
100
)
}
func
(
s
*
SoraGatewayService
)
resolveWatermarkFreeURL
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
,
opts
soraWatermarkOptions
)
(
string
,
string
,
error
)
{
generationID
=
strings
.
TrimSpace
(
generationID
)
if
generationID
==
""
{
return
""
,
""
,
errors
.
New
(
"generation id is required for watermark-free mode"
)
}
postID
,
err
:=
s
.
soraClient
.
PostVideoForWatermarkFree
(
ctx
,
account
,
generationID
)
if
err
!=
nil
{
return
""
,
""
,
err
}
postID
=
strings
.
TrimSpace
(
postID
)
if
postID
==
""
{
return
""
,
""
,
errors
.
New
(
"watermark-free publish returned empty post id"
)
}
switch
opts
.
ParseMethod
{
case
"custom"
:
urlVal
,
parseErr
:=
s
.
soraClient
.
GetWatermarkFreeURLCustom
(
ctx
,
account
,
opts
.
ParseURL
,
opts
.
ParseToken
,
postID
)
if
parseErr
!=
nil
{
return
""
,
postID
,
parseErr
}
return
strings
.
TrimSpace
(
urlVal
),
postID
,
nil
case
""
,
"third_party"
:
return
fmt
.
Sprintf
(
"https://oscdn2.dyysy.com/MP4/%s.mp4"
,
postID
),
postID
,
nil
default
:
return
""
,
postID
,
fmt
.
Errorf
(
"unsupported watermark parse method: %s"
,
opts
.
ParseMethod
)
}
}
func
(
s
*
SoraGatewayService
)
shouldFailoverUpstreamError
(
statusCode
int
)
bool
{
switch
statusCode
{
case
401
,
402
,
403
,
404
,
429
,
529
:
return
true
default
:
return
statusCode
>=
500
}
}
func
buildSoraNonStreamResponse
(
content
,
model
string
)
map
[
string
]
any
{
return
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"message"
:
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
content
,
},
"finish_reason"
:
"stop"
,
},
},
}
}
func
soraImageSizeFromModel
(
model
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
size
,
ok
:=
soraImageSizeMap
[
modelLower
];
ok
{
return
size
}
if
strings
.
Contains
(
modelLower
,
"landscape"
)
||
strings
.
Contains
(
modelLower
,
"portrait"
)
{
return
"540"
}
return
"360"
}
func
soraProErrorMessage
(
model
,
upstreamMsg
string
)
string
{
modelLower
:=
strings
.
ToLower
(
model
)
if
strings
.
Contains
(
modelLower
,
"sora2pro-hd"
)
{
return
"当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
}
if
strings
.
Contains
(
modelLower
,
"sora2pro"
)
{
return
"当前账号无法使用 Sora Pro 模型,请更换模型或账号"
}
return
""
}
func
firstMediaURL
(
urls
[]
string
)
string
{
if
len
(
urls
)
==
0
{
return
""
}
return
urls
[
0
]
}
func
(
s
*
SoraGatewayService
)
buildSoraMediaURL
(
path
string
,
rawQuery
string
)
string
{
if
path
==
""
{
return
path
}
prefix
:=
"/sora/media"
values
:=
url
.
Values
{}
if
rawQuery
!=
""
{
if
parsed
,
err
:=
url
.
ParseQuery
(
rawQuery
);
err
==
nil
{
values
=
parsed
}
}
signKey
:=
""
ttlSeconds
:=
0
if
s
!=
nil
&&
s
.
cfg
!=
nil
{
signKey
=
strings
.
TrimSpace
(
s
.
cfg
.
Gateway
.
SoraMediaSigningKey
)
ttlSeconds
=
s
.
cfg
.
Gateway
.
SoraMediaSignedURLTTLSeconds
}
values
.
Del
(
"sig"
)
values
.
Del
(
"expires"
)
signingQuery
:=
values
.
Encode
()
if
signKey
!=
""
&&
ttlSeconds
>
0
{
expires
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
ttlSeconds
)
*
time
.
Second
)
.
Unix
()
signature
:=
SignSoraMediaURL
(
path
,
signingQuery
,
expires
,
signKey
)
if
signature
!=
""
{
values
.
Set
(
"expires"
,
strconv
.
FormatInt
(
expires
,
10
))
values
.
Set
(
"sig"
,
signature
)
prefix
=
"/sora/media-signed"
}
}
encoded
:=
values
.
Encode
()
if
encoded
==
""
{
return
prefix
+
path
}
return
prefix
+
path
+
"?"
+
encoded
}
func
(
s
*
SoraGatewayService
)
prepareSoraStream
(
c
*
gin
.
Context
,
requestID
string
)
{
if
c
==
nil
{
return
}
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
strings
.
TrimSpace
(
requestID
)
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
}
func
(
s
*
SoraGatewayService
)
writeSoraStream
(
c
*
gin
.
Context
,
model
,
content
string
,
startTime
time
.
Time
)
(
*
int
,
error
)
{
if
c
==
nil
{
return
nil
,
nil
}
writer
:=
c
.
Writer
flusher
,
_
:=
writer
.
(
http
.
Flusher
)
chunk
:=
map
[
string
]
any
{
"id"
:
fmt
.
Sprintf
(
"chatcmpl-%d"
,
time
.
Now
()
.
UnixNano
()),
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{
"content"
:
content
,
},
},
},
}
encoded
,
_
:=
jsonMarshalRaw
(
chunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
encoded
);
err
!=
nil
{
return
nil
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
finalChunk
:=
map
[
string
]
any
{
"id"
:
chunk
[
"id"
],
"object"
:
"chat.completion.chunk"
,
"created"
:
time
.
Now
()
.
Unix
(),
"model"
:
model
,
"choices"
:
[]
any
{
map
[
string
]
any
{
"index"
:
0
,
"delta"
:
map
[
string
]
any
{},
"finish_reason"
:
"stop"
,
},
},
}
finalEncoded
,
_
:=
jsonMarshalRaw
(
finalChunk
)
if
_
,
err
:=
fmt
.
Fprintf
(
writer
,
"data: %s
\n\n
"
,
finalEncoded
);
err
!=
nil
{
return
&
ms
,
err
}
if
_
,
err
:=
fmt
.
Fprint
(
writer
,
"data: [DONE]
\n\n
"
);
err
!=
nil
{
return
&
ms
,
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
&
ms
,
nil
}
func
(
s
*
SoraGatewayService
)
writeSoraError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
stream
bool
)
{
if
c
==
nil
{
return
}
if
stream
{
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
)
_
,
_
=
fmt
.
Fprint
(
c
.
Writer
,
"data: [DONE]
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
}
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
,
},
})
}
func
(
s
*
SoraGatewayService
)
handleSoraRequestError
(
ctx
context
.
Context
,
account
*
Account
,
err
error
,
model
string
,
c
*
gin
.
Context
,
stream
bool
)
error
{
if
err
==
nil
{
return
nil
}
var
upstreamErr
*
SoraUpstreamError
if
errors
.
As
(
err
,
&
upstreamErr
)
{
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
logger
.
LegacyPrintf
(
"service.sora"
,
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s"
,
accountID
,
model
,
upstreamErr
.
StatusCode
,
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"x-request-id"
)),
strings
.
TrimSpace
(
upstreamErr
.
Headers
.
Get
(
"cf-ray"
)),
strings
.
TrimSpace
(
upstreamErr
.
Message
),
truncateForLog
(
upstreamErr
.
Body
,
1024
),
)
if
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
upstreamErr
.
StatusCode
,
upstreamErr
.
Headers
,
upstreamErr
.
Body
)
}
if
s
.
shouldFailoverUpstreamError
(
upstreamErr
.
StatusCode
)
{
var
responseHeaders
http
.
Header
if
upstreamErr
.
Headers
!=
nil
{
responseHeaders
=
upstreamErr
.
Headers
.
Clone
()
}
return
&
UpstreamFailoverError
{
StatusCode
:
upstreamErr
.
StatusCode
,
ResponseBody
:
upstreamErr
.
Body
,
ResponseHeaders
:
responseHeaders
,
}
}
msg
:=
upstreamErr
.
Message
if
override
:=
soraProErrorMessage
(
model
,
msg
);
override
!=
""
{
msg
=
override
}
s
.
writeSoraError
(
c
,
upstreamErr
.
StatusCode
,
"upstream_error"
,
msg
,
stream
)
return
err
}
if
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
s
.
writeSoraError
(
c
,
http
.
StatusGatewayTimeout
,
"timeout_error"
,
"Sora generation timeout"
,
stream
)
return
err
}
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"api_error"
,
err
.
Error
(),
stream
)
return
err
}
func
(
s
*
SoraGatewayService
)
pollImageTask
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
([]
string
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetImageTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"succeeded"
,
"completed"
:
return
status
.
URLs
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora image generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora image generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollVideoTaskDetailed
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
taskID
string
,
stream
bool
)
(
*
SoraVideoTaskStatus
,
error
)
{
interval
:=
s
.
pollInterval
()
maxAttempts
:=
s
.
pollMaxAttempts
()
lastPing
:=
time
.
Now
()
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
status
,
err
:=
s
.
soraClient
.
GetVideoTask
(
ctx
,
account
,
taskID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
strings
.
ToLower
(
status
.
Status
)
{
case
"completed"
,
"succeeded"
:
return
status
,
nil
case
"failed"
:
if
status
.
ErrorMsg
!=
""
{
return
nil
,
errors
.
New
(
status
.
ErrorMsg
)
}
return
nil
,
errors
.
New
(
"sora video generation failed"
)
}
if
stream
{
s
.
maybeSendPing
(
c
,
&
lastPing
)
}
if
err
:=
sleepWithContext
(
ctx
,
interval
);
err
!=
nil
{
return
nil
,
err
}
}
return
nil
,
errors
.
New
(
"sora video generation timeout"
)
}
func
(
s
*
SoraGatewayService
)
pollInterval
()
time
.
Duration
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
2
*
time
.
Second
}
interval
:=
s
.
cfg
.
Sora
.
Client
.
PollIntervalSeconds
if
interval
<=
0
{
interval
=
2
}
return
time
.
Duration
(
interval
)
*
time
.
Second
}
func
(
s
*
SoraGatewayService
)
pollMaxAttempts
()
int
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
600
}
maxAttempts
:=
s
.
cfg
.
Sora
.
Client
.
MaxPollAttempts
if
maxAttempts
<=
0
{
maxAttempts
=
600
}
return
maxAttempts
}
func
(
s
*
SoraGatewayService
)
maybeSendPing
(
c
*
gin
.
Context
,
lastPing
*
time
.
Time
)
{
if
c
==
nil
{
return
}
interval
:=
10
*
time
.
Second
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
Concurrency
.
PingInterval
>
0
{
interval
=
time
.
Duration
(
s
.
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
}
if
time
.
Since
(
*
lastPing
)
<
interval
{
return
}
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
":
\n\n
"
);
err
==
nil
{
if
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
);
ok
{
flusher
.
Flush
()
}
*
lastPing
=
time
.
Now
()
}
}
func
(
s
*
SoraGatewayService
)
normalizeSoraMediaURLs
(
urls
[]
string
)
[]
string
{
if
len
(
urls
)
==
0
{
return
urls
}
output
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
continue
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
output
=
append
(
output
,
raw
)
continue
}
pathVal
:=
raw
if
!
strings
.
HasPrefix
(
pathVal
,
"/"
)
{
pathVal
=
"/"
+
pathVal
}
output
=
append
(
output
,
s
.
buildSoraMediaURL
(
pathVal
,
""
))
}
return
output
}
// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
func
jsonMarshalRaw
(
v
any
)
([]
byte
,
error
)
{
var
buf
bytes
.
Buffer
enc
:=
json
.
NewEncoder
(
&
buf
)
enc
.
SetEscapeHTML
(
false
)
if
err
:=
enc
.
Encode
(
v
);
err
!=
nil
{
return
nil
,
err
}
// Encode 会追加换行符,去掉它
b
:=
buf
.
Bytes
()
if
len
(
b
)
>
0
&&
b
[
len
(
b
)
-
1
]
==
'\n'
{
b
=
b
[
:
len
(
b
)
-
1
]
}
return
b
,
nil
}
func
buildSoraContent
(
mediaType
string
,
urls
[]
string
)
string
{
switch
mediaType
{
case
"image"
:
parts
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
u
:=
range
urls
{
parts
=
append
(
parts
,
fmt
.
Sprintf
(
""
,
u
))
}
return
strings
.
Join
(
parts
,
"
\n
"
)
case
"video"
:
if
len
(
urls
)
==
0
{
return
""
}
return
fmt
.
Sprintf
(
"```html
\n
<video src='%s' controls></video>
\n
```"
,
urls
[
0
])
default
:
return
""
}
}
func
extractSoraInput
(
body
map
[
string
]
any
)
(
prompt
,
imageInput
,
videoInput
,
remixTargetID
string
)
{
if
body
==
nil
{
return
""
,
""
,
""
,
""
}
if
v
,
ok
:=
body
[
"remix_target_id"
]
.
(
string
);
ok
{
remixTargetID
=
strings
.
TrimSpace
(
v
)
}
if
v
,
ok
:=
body
[
"image"
]
.
(
string
);
ok
{
imageInput
=
v
}
if
v
,
ok
:=
body
[
"video"
]
.
(
string
);
ok
{
videoInput
=
v
}
if
v
,
ok
:=
body
[
"prompt"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
v
)
!=
""
{
prompt
=
v
}
if
messages
,
ok
:=
body
[
"messages"
]
.
([]
any
);
ok
{
builder
:=
strings
.
Builder
{}
for
_
,
raw
:=
range
messages
{
msg
,
ok
:=
raw
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
role
,
_
:=
msg
[
"role"
]
.
(
string
)
if
role
!=
""
&&
role
!=
"user"
{
continue
}
content
:=
msg
[
"content"
]
text
,
img
,
vid
:=
parseSoraMessageContent
(
content
)
if
text
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
text
)
}
if
imageInput
==
""
&&
img
!=
""
{
imageInput
=
img
}
if
videoInput
==
""
&&
vid
!=
""
{
videoInput
=
vid
}
}
if
prompt
==
""
{
prompt
=
builder
.
String
()
}
}
if
remixTargetID
==
""
{
remixTargetID
=
extractRemixTargetIDFromPrompt
(
prompt
)
}
prompt
=
cleanRemixLinkFromPrompt
(
prompt
)
return
prompt
,
imageInput
,
videoInput
,
remixTargetID
}
func
parseSoraMessageContent
(
content
any
)
(
text
,
imageInput
,
videoInput
string
)
{
switch
val
:=
content
.
(
type
)
{
case
string
:
return
val
,
""
,
""
case
[]
any
:
builder
:=
strings
.
Builder
{}
for
_
,
item
:=
range
val
{
itemMap
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
t
,
_
:=
itemMap
[
"type"
]
.
(
string
)
switch
t
{
case
"text"
:
if
txt
,
ok
:=
itemMap
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
txt
)
!=
""
{
if
builder
.
Len
()
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n
"
)
}
_
,
_
=
builder
.
WriteString
(
txt
)
}
case
"image_url"
:
if
imageInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"image_url"
]
.
(
map
[
string
]
any
);
ok
{
imageInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"image_url"
]
.
(
string
);
ok
{
imageInput
=
urlStr
}
}
case
"video_url"
:
if
videoInput
==
""
{
if
urlVal
,
ok
:=
itemMap
[
"video_url"
]
.
(
map
[
string
]
any
);
ok
{
videoInput
=
fmt
.
Sprintf
(
"%v"
,
urlVal
[
"url"
])
}
else
if
urlStr
,
ok
:=
itemMap
[
"video_url"
]
.
(
string
);
ok
{
videoInput
=
urlStr
}
}
}
}
return
builder
.
String
(),
imageInput
,
videoInput
default
:
return
""
,
""
,
""
}
}
func
isSoraStoryboardPrompt
(
prompt
string
)
bool
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
false
}
return
len
(
soraStoryboardPattern
.
FindAllString
(
prompt
,
-
1
))
>=
1
}
func
formatSoraStoryboardPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
matches
:=
soraStoryboardShotPattern
.
FindAllStringSubmatch
(
prompt
,
-
1
)
if
len
(
matches
)
==
0
{
return
prompt
}
firstBracketPos
:=
strings
.
Index
(
prompt
,
"["
)
instructions
:=
""
if
firstBracketPos
>
0
{
instructions
=
strings
.
TrimSpace
(
prompt
[
:
firstBracketPos
])
}
shots
:=
make
([]
string
,
0
,
len
(
matches
))
for
i
,
match
:=
range
matches
{
if
len
(
match
)
<
3
{
continue
}
duration
:=
strings
.
TrimSpace
(
match
[
1
])
scene
:=
strings
.
TrimSpace
(
match
[
2
])
if
scene
==
""
{
continue
}
shots
=
append
(
shots
,
fmt
.
Sprintf
(
"Shot %d:
\n
duration: %ssec
\n
Scene: %s"
,
i
+
1
,
duration
,
scene
))
}
if
len
(
shots
)
==
0
{
return
prompt
}
timeline
:=
strings
.
Join
(
shots
,
"
\n\n
"
)
if
instructions
==
""
{
return
timeline
}
return
fmt
.
Sprintf
(
"current timeline:
\n
%s
\n\n
instructions:
\n
%s"
,
timeline
,
instructions
)
}
func
extractRemixTargetIDFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
""
}
return
strings
.
TrimSpace
(
soraRemixTargetPattern
.
FindString
(
prompt
))
}
func
cleanRemixLinkFromPrompt
(
prompt
string
)
string
{
prompt
=
strings
.
TrimSpace
(
prompt
)
if
prompt
==
""
{
return
prompt
}
cleaned
:=
soraRemixTargetInURLPattern
.
ReplaceAllString
(
prompt
,
""
)
cleaned
=
soraRemixTargetPattern
.
ReplaceAllString
(
cleaned
,
""
)
cleaned
=
strings
.
Join
(
strings
.
Fields
(
cleaned
),
" "
)
return
strings
.
TrimSpace
(
cleaned
)
}
func
decodeSoraImageInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
string
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
""
,
errors
.
New
(
"empty image input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
""
,
errors
.
New
(
"invalid data url"
)
}
meta
:=
parts
[
0
]
payload
:=
parts
[
1
]
decoded
,
err
:=
decodeBase64WithLimit
(
payload
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
""
if
strings
.
HasPrefix
(
meta
,
"data:"
)
{
metaParts
:=
strings
.
SplitN
(
meta
[
5
:
],
";"
,
2
)
if
len
(
metaParts
)
>
0
{
if
exts
,
err
:=
mime
.
ExtensionsByType
(
metaParts
[
0
]);
err
==
nil
&&
len
(
exts
)
>
0
{
ext
=
exts
[
0
]
}
}
}
filename
:=
"image"
+
ext
return
decoded
,
filename
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraImageInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraImageInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
""
,
errors
.
New
(
"invalid base64 image"
)
}
return
decoded
,
"image.png"
,
nil
}
func
decodeSoraVideoInput
(
ctx
context
.
Context
,
input
string
)
([]
byte
,
error
)
{
raw
:=
strings
.
TrimSpace
(
input
)
if
raw
==
""
{
return
nil
,
errors
.
New
(
"empty video input"
)
}
if
strings
.
HasPrefix
(
raw
,
"data:"
)
{
parts
:=
strings
.
SplitN
(
raw
,
","
,
2
)
if
len
(
parts
)
!=
2
{
return
nil
,
errors
.
New
(
"invalid video data url"
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
parts
[
1
],
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
if
strings
.
HasPrefix
(
raw
,
"http://"
)
||
strings
.
HasPrefix
(
raw
,
"https://"
)
{
return
downloadSoraVideoInput
(
ctx
,
raw
)
}
decoded
,
err
:=
decodeBase64WithLimit
(
raw
,
soraVideoInputMaxBytes
)
if
err
!=
nil
{
return
nil
,
errors
.
New
(
"invalid base64 video"
)
}
if
len
(
decoded
)
==
0
{
return
nil
,
errors
.
New
(
"empty video data"
)
}
return
decoded
,
nil
}
func
downloadSoraImageInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
string
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraImageInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraImageInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
""
,
fmt
.
Errorf
(
"download image failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraImageInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
""
,
err
}
ext
:=
fileExtFromURL
(
parsed
.
String
())
if
ext
==
""
{
ext
=
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
))
}
filename
:=
"image"
+
ext
return
data
,
filename
,
nil
}
func
downloadSoraVideoInput
(
ctx
context
.
Context
,
rawURL
string
)
([]
byte
,
error
)
{
parsed
,
err
:=
validateSoraRemoteURL
(
rawURL
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
parsed
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
soraVideoInputTimeout
,
CheckRedirect
:
func
(
req
*
http
.
Request
,
via
[]
*
http
.
Request
)
error
{
if
len
(
via
)
>=
soraVideoInputMaxRedirects
{
return
errors
.
New
(
"too many redirects"
)
}
return
validateSoraRemoteURLValue
(
req
.
URL
)
},
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"download video failed: %d"
,
resp
.
StatusCode
)
}
data
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
soraVideoInputMaxBytes
))
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
data
)
==
0
{
return
nil
,
errors
.
New
(
"empty video content"
)
}
return
data
,
nil
}
func
decodeBase64WithLimit
(
encoded
string
,
maxBytes
int64
)
([]
byte
,
error
)
{
if
maxBytes
<=
0
{
return
nil
,
errors
.
New
(
"invalid max bytes limit"
)
}
decoder
:=
base64
.
NewDecoder
(
base64
.
StdEncoding
,
strings
.
NewReader
(
encoded
))
limited
:=
io
.
LimitReader
(
decoder
,
maxBytes
+
1
)
data
,
err
:=
io
.
ReadAll
(
limited
)
if
err
!=
nil
{
return
nil
,
err
}
if
int64
(
len
(
data
))
>
maxBytes
{
return
nil
,
fmt
.
Errorf
(
"input exceeds %d bytes limit"
,
maxBytes
)
}
return
data
,
nil
}
func
validateSoraRemoteURL
(
raw
string
)
(
*
url
.
URL
,
error
)
{
if
strings
.
TrimSpace
(
raw
)
==
""
{
return
nil
,
errors
.
New
(
"empty remote url"
)
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid remote url: %w"
,
err
)
}
if
err
:=
validateSoraRemoteURLValue
(
parsed
);
err
!=
nil
{
return
nil
,
err
}
return
parsed
,
nil
}
func
validateSoraRemoteURLValue
(
parsed
*
url
.
URL
)
error
{
if
parsed
==
nil
{
return
errors
.
New
(
"invalid remote url"
)
}
scheme
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Scheme
))
if
scheme
!=
"http"
&&
scheme
!=
"https"
{
return
errors
.
New
(
"only http/https remote url is allowed"
)
}
if
parsed
.
User
!=
nil
{
return
errors
.
New
(
"remote url cannot contain userinfo"
)
}
host
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
parsed
.
Hostname
()))
if
host
==
""
{
return
errors
.
New
(
"remote url missing host"
)
}
if
_
,
blocked
:=
soraBlockedHostnames
[
host
];
blocked
{
return
errors
.
New
(
"remote url is not allowed"
)
}
if
ip
:=
net
.
ParseIP
(
host
);
ip
!=
nil
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
return
nil
}
ips
,
err
:=
net
.
LookupIP
(
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"resolve remote url failed: %w"
,
err
)
}
for
_
,
ip
:=
range
ips
{
if
isSoraBlockedIP
(
ip
)
{
return
errors
.
New
(
"remote url is not allowed"
)
}
}
return
nil
}
func
isSoraBlockedIP
(
ip
net
.
IP
)
bool
{
if
ip
==
nil
{
return
true
}
for
_
,
cidr
:=
range
soraBlockedCIDRs
{
if
cidr
.
Contains
(
ip
)
{
return
true
}
}
return
false
}
func
mustParseCIDRs
(
values
[]
string
)
[]
*
net
.
IPNet
{
out
:=
make
([]
*
net
.
IPNet
,
0
,
len
(
values
))
for
_
,
val
:=
range
values
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
val
)
if
err
!=
nil
{
continue
}
out
=
append
(
out
,
cidr
)
}
return
out
}
backend/internal/service/sora_gateway_service_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var
_
SoraClient
=
(
*
stubSoraClientForPoll
)(
nil
)
type
stubSoraClientForPoll
struct
{
imageStatus
*
SoraImageTaskStatus
videoStatus
*
SoraVideoTaskStatus
imageCalls
int
videoCalls
int
enhanced
string
enhanceErr
error
storyboard
bool
videoReq
SoraVideoRequest
parseErr
error
postCalls
int
deleteCalls
int
}
func
(
s
*
stubSoraClientForPoll
)
Enabled
()
bool
{
return
true
}
func
(
s
*
stubSoraClientForPoll
)
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
{
return
"task-image"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
{
s
.
videoReq
=
req
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
{
s
.
storyboard
=
true
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
return
&
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClientForPoll
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClientForPoll
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
{
s
.
postCalls
++
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
{
s
.
deleteCalls
++
return
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
if
s
.
parseErr
!=
nil
{
return
""
,
s
.
parseErr
}
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
if
s
.
enhanced
!=
""
{
return
s
.
enhanced
,
s
.
enhanceErr
}
return
"enhanced prompt"
,
s
.
enhanceErr
}
func
(
s
*
stubSoraClientForPoll
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
{
s
.
imageCalls
++
return
s
.
imageStatus
,
nil
}
func
(
s
*
stubSoraClientForPoll
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
{
s
.
videoCalls
++
return
s
.
videoStatus
,
nil
}
func
TestSoraGatewayService_PollImageTaskCompleted
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
imageStatus
:
&
SoraImageTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/a.png"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
urls
,
err
:=
service
.
pollImageTask
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"https://example.com/a.png"
},
urls
)
require
.
Equal
(
t
,
1
,
client
.
imageCalls
)
}
func
TestSoraGatewayService_ForwardPromptEnhance
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
enhanced
:
"cinematic prompt"
,
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"prompt-enhance-short-10s"
:
"prompt-enhance-short-15s"
,
},
},
}
body
:=
[]
byte
(
`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
"prompt-enhance-short-10s"
,
result
.
Model
)
require
.
Equal
(
t
,
"prompt-enhance-short-15s"
,
result
.
UpstreamModel
)
}
func
TestSoraGatewayService_ForwardStoryboardPrompt
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
client
.
storyboard
)
}
func
TestSoraGatewayService_ForwardVideoCount
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/v.mp4"
},
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
3
,
client
.
videoReq
.
VideoCount
)
}
func
TestSoraGatewayService_ForwardCharacterOnly
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"prompt"
,
result
.
MediaType
)
require
.
Equal
(
t
,
0
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkFallback
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
parseErr
:
errors
.
New
(
"parse failed"
),
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/original.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
0
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"completed"
,
URLs
:
[]
string
{
"https://example.com/original.mp4"
},
GenerationID
:
"gen_1"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
svc
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
}
body
:=
[]
byte
(
`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`
)
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
nil
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"https://example.com/no-watermark.mp4"
,
result
.
MediaURL
)
require
.
Equal
(
t
,
1
,
client
.
postCalls
)
require
.
Equal
(
t
,
1
,
client
.
deleteCalls
)
}
func
TestSoraGatewayService_PollVideoTaskFailed
(
t
*
testing
.
T
)
{
client
:=
&
stubSoraClientForPoll
{
videoStatus
:
&
SoraVideoTaskStatus
{
Status
:
"failed"
,
ErrorMsg
:
"reject"
,
},
}
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
1
,
MaxPollAttempts
:
1
,
},
},
}
service
:=
NewSoraGatewayService
(
client
,
nil
,
nil
,
cfg
)
status
,
err
:=
service
.
pollVideoTaskDetailed
(
context
.
Background
(),
nil
,
&
Account
{
ID
:
1
},
"task"
,
false
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
status
)
require
.
Contains
(
t
,
err
.
Error
(),
"reject"
)
require
.
Equal
(
t
,
1
,
client
.
videoCalls
)
}
func
TestSoraGatewayService_BuildSoraMediaURLSigned
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraMediaSigningKey
:
"test-key"
,
SoraMediaSignedURLTTLSeconds
:
600
,
},
}
service
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
url
:=
service
.
buildSoraMediaURL
(
"/image/2025/01/01/a.png"
,
""
)
require
.
Contains
(
t
,
url
,
"/sora/media-signed"
)
require
.
Contains
(
t
,
url
,
"expires="
)
require
.
Contains
(
t
,
url
,
"sig="
)
}
func
TestNormalizeSoraMediaURLs_Empty
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
result
:=
svc
.
normalizeSoraMediaURLs
(
nil
)
require
.
Empty
(
t
,
result
)
result
=
svc
.
normalizeSoraMediaURLs
([]
string
{})
require
.
Empty
(
t
,
result
)
}
func
TestNormalizeSoraMediaURLs_HTTPUrls
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
"http://example.com/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Equal
(
t
,
urls
,
result
)
}
func
TestNormalizeSoraMediaURLs_LocalPaths
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
urls
:=
[]
string
{
"/image/2025/01/a.png"
,
"video/2025/01/b.mp4"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
require
.
Contains
(
t
,
result
[
0
],
"/sora/media"
)
require
.
Contains
(
t
,
result
[
1
],
"/sora/media"
)
}
func
TestNormalizeSoraMediaURLs_SkipsBlank
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
urls
:=
[]
string
{
"https://example.com/a.png"
,
""
,
" "
,
"https://example.com/b.png"
}
result
:=
svc
.
normalizeSoraMediaURLs
(
urls
)
require
.
Len
(
t
,
result
,
2
)
}
func
TestBuildSoraContent_Image
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"image"
,
[]
string
{
"https://a.com/1.png"
,
"https://a.com/2.png"
})
require
.
Contains
(
t
,
content
,
""
)
require
.
Contains
(
t
,
content
,
""
)
}
func
TestBuildSoraContent_Video
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
[]
string
{
"https://a.com/v.mp4"
})
require
.
Contains
(
t
,
content
,
"<video src='https://a.com/v.mp4'"
)
}
func
TestBuildSoraContent_VideoEmpty
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"video"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestBuildSoraContent_Prompt
(
t
*
testing
.
T
)
{
content
:=
buildSoraContent
(
"prompt"
,
nil
)
require
.
Empty
(
t
,
content
)
}
func
TestSoraImageSizeFromModel
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"gpt-image"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-landscape"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"gpt-image-portrait"
))
require
.
Equal
(
t
,
"540"
,
soraImageSizeFromModel
(
"something-landscape"
))
require
.
Equal
(
t
,
"360"
,
soraImageSizeFromModel
(
"unknown-model"
))
}
func
TestFirstMediaURL
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
""
,
firstMediaURL
(
nil
))
require
.
Equal
(
t
,
""
,
firstMediaURL
([]
string
{}))
require
.
Equal
(
t
,
"a"
,
firstMediaURL
([]
string
{
"a"
,
"b"
}))
}
func
TestSoraProErrorMessage
(
t
*
testing
.
T
)
{
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro-hd"
,
""
),
"Pro-HD"
)
require
.
Contains
(
t
,
soraProErrorMessage
(
"sora2pro"
,
""
),
"Pro"
)
require
.
Empty
(
t
,
soraProErrorMessage
(
"sora-basic"
,
""
))
}
func
TestSoraGatewayService_WriteSoraError_StreamEscapesJSON
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
svc
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"invalid
\"
prompt
\"\n
line2"
,
true
)
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: error
\n
"
)
require
.
Contains
(
t
,
body
,
"data: [DONE]
\n\n
"
)
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
require
.
GreaterOrEqual
(
t
,
len
(
lines
),
2
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
))
data
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
data
),
&
parsed
))
errObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errObj
[
"message"
])
}
func
TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
sourceHeaders
:=
http
.
Header
{}
sourceHeaders
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
err
:=
svc
.
handleSoraRequestError
(
context
.
Background
(),
&
Account
{
ID
:
1
,
Platform
:
PlatformSora
},
&
SoraUpstreamError
{
StatusCode
:
http
.
StatusForbidden
,
Message
:
"forbidden"
,
Headers
:
sourceHeaders
,
Body
:
[]
byte
(
`<!DOCTYPE html><title>Just a moment...</title>`
),
},
"sora2-landscape-10s"
,
nil
,
false
,
)
var
failoverErr
*
UpstreamFailoverError
require
.
ErrorAs
(
t
,
err
,
&
failoverErr
)
require
.
NotNil
(
t
,
failoverErr
.
ResponseHeaders
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
sourceHeaders
.
Set
(
"cf-ray"
,
"mutated-after-return"
)
require
.
Equal
(
t
,
"9d01b0e9ecc35829-SEA"
,
failoverErr
.
ResponseHeaders
.
Get
(
"cf-ray"
))
}
func
TestShouldFailoverUpstreamError
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
401
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
404
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
429
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
500
))
require
.
True
(
t
,
svc
.
shouldFailoverUpstreamError
(
502
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
200
))
require
.
False
(
t
,
svc
.
shouldFailoverUpstreamError
(
400
))
}
func
TestWithSoraTimeout_NilService
(
t
*
testing
.
T
)
{
var
svc
*
SoraGatewayService
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_ZeroTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
Nil
(
t
,
cancel
)
}
func
TestWithSoraTimeout_PositiveTimeout
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
SoraRequestTimeoutSeconds
:
30
,
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
ctx
,
cancel
:=
svc
.
withSoraTimeout
(
context
.
Background
(),
false
)
require
.
NotNil
(
t
,
ctx
)
require
.
NotNil
(
t
,
cancel
)
cancel
()
}
func
TestPollInterval
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
PollIntervalSeconds
:
5
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
5
*
time
.
Second
,
svc
.
pollInterval
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollInterval
()
>
0
)
}
func
TestPollMaxAttempts
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Client
:
config
.
SoraClientConfig
{
MaxPollAttempts
:
100
,
},
},
}
svc
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
cfg
)
require
.
Equal
(
t
,
100
,
svc
.
pollMaxAttempts
())
// 默认值
svc2
:=
NewSoraGatewayService
(
nil
,
nil
,
nil
,
&
config
.
Config
{})
require
.
True
(
t
,
svc2
.
pollMaxAttempts
()
>
0
)
}
func
TestDecodeSoraImageInput_BlockPrivateURL
(
t
*
testing
.
T
)
{
_
,
_
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
"http://127.0.0.1/internal.png"
)
require
.
Error
(
t
,
err
)
}
func
TestDecodeSoraImageInput_DataURL
(
t
*
testing
.
T
)
{
encoded
:=
"data:image/png;base64,aGVsbG8="
data
,
filename
,
err
:=
decodeSoraImageInput
(
context
.
Background
(),
encoded
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
data
)
require
.
Contains
(
t
,
filename
,
".png"
)
}
func
TestDecodeBase64WithLimit_ExceedLimit
(
t
*
testing
.
T
)
{
data
,
err
:=
decodeBase64WithLimit
(
"aGVsbG8="
,
3
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
data
)
}
func
TestParseSoraWatermarkOptions_NumericBool
(
t
*
testing
.
T
)
{
body
:=
map
[
string
]
any
{
"watermark_free"
:
float64
(
1
),
"watermark_fallback_on_failure"
:
float64
(
0
),
}
opts
:=
parseSoraWatermarkOptions
(
body
)
require
.
True
(
t
,
opts
.
Enabled
)
require
.
False
(
t
,
opts
.
FallbackOnFailure
)
}
func
TestParseSoraVideoCount
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
nil
))
require
.
Equal
(
t
,
2
,
parseSoraVideoCount
(
map
[
string
]
any
{
"video_count"
:
float64
(
2
)}))
require
.
Equal
(
t
,
3
,
parseSoraVideoCount
(
map
[
string
]
any
{
"videos"
:
"5"
}))
require
.
Equal
(
t
,
1
,
parseSoraVideoCount
(
map
[
string
]
any
{
"n_variants"
:
0
}))
}
backend/internal/service/sora_gateway_streaming_legacy.go
deleted
100644 → 0
View file @
f585a15e
//nolint:unused
package
service
import
(
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var
soraSSEDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
var
soraImageMarkdownRe
=
regexp
.
MustCompile
(
`!\[[^\]]*\]\(([^)]+)\)`
)
var
soraVideoHTMLRe
=
regexp
.
MustCompile
(
`(?i)<video[^>]+src=['"]([^'"]+)['"]`
)
const
soraRewriteBufferLimit
=
2048
type
soraStreamingResult
struct
{
mediaType
string
mediaURLs
[]
string
imageCount
int
imageSize
string
firstTokenMs
*
int
}
func
(
s
*
SoraGatewayService
)
setUpstreamRequestError
(
c
*
gin
.
Context
,
account
*
Account
,
err
error
)
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
}
}
func
(
s
*
SoraGatewayService
)
handleFailoverSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
if
s
.
rateLimitService
==
nil
||
account
==
nil
||
resp
==
nil
{
return
}
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
func
(
s
*
SoraGatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
reqModel
string
)
(
*
ForwardResult
,
error
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
msg
:=
soraProErrorMessage
(
reqModel
,
upstreamMsg
);
msg
!=
""
{
upstreamMsg
=
msg
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
c
!=
nil
{
responsePayload
:=
s
.
buildErrorPayload
(
respBody
,
upstreamMsg
)
c
.
JSON
(
resp
.
StatusCode
,
responsePayload
)
}
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
func
(
s
*
SoraGatewayService
)
buildErrorPayload
(
respBody
[]
byte
,
overrideMessage
string
)
map
[
string
]
any
{
if
len
(
respBody
)
>
0
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
respBody
,
&
payload
);
err
==
nil
{
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
overrideMessage
!=
""
{
errObj
[
"message"
]
=
overrideMessage
}
payload
[
"error"
]
=
errObj
return
payload
}
}
}
return
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
overrideMessage
,
},
}
}
func
(
s
*
SoraGatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
string
,
clientStream
bool
)
(
*
soraStreamingResult
,
error
)
{
if
resp
==
nil
{
return
nil
,
errors
.
New
(
"empty response"
)
}
if
clientStream
{
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
if
v
:=
resp
.
Header
.
Get
(
"x-request-id"
);
v
!=
""
{
c
.
Header
(
"x-request-id"
,
v
)
}
}
w
:=
c
.
Writer
flusher
,
_
:=
w
.
(
http
.
Flusher
)
contentBuilder
:=
strings
.
Builder
{}
var
firstTokenMs
*
int
var
upstreamError
error
rewriteBuffer
:=
""
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
sendLine
:=
func
(
line
string
)
error
{
if
!
clientStream
{
return
nil
}
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
err
}
if
flusher
!=
nil
{
flusher
.
Flush
()
}
return
nil
}
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
if
soraSSEDataRe
.
MatchString
(
line
)
{
data
:=
soraSSEDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
"[DONE]"
{
if
rewriteBuffer
!=
""
{
flushLine
,
flushContent
,
err
:=
s
.
flushSoraRewriteBuffer
(
rewriteBuffer
,
originalModel
)
if
err
!=
nil
{
return
nil
,
err
}
if
flushLine
!=
""
{
if
flushContent
!=
""
{
if
_
,
err
:=
contentBuilder
.
WriteString
(
flushContent
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
flushLine
);
err
!=
nil
{
return
nil
,
err
}
}
rewriteBuffer
=
""
}
if
err
:=
sendLine
(
"data: [DONE]"
);
err
!=
nil
{
return
nil
,
err
}
break
}
updatedLine
,
contentDelta
,
errEvent
:=
s
.
processSoraSSEData
(
data
,
originalModel
,
&
rewriteBuffer
)
if
errEvent
!=
nil
&&
upstreamError
==
nil
{
upstreamError
=
errEvent
}
if
contentDelta
!=
""
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
err
:=
contentBuilder
.
WriteString
(
contentDelta
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
sendLine
(
updatedLine
);
err
!=
nil
{
return
nil
,
err
}
continue
}
if
err
:=
sendLine
(
line
);
err
!=
nil
{
return
nil
,
err
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
if
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
response_too_large
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
if
ctx
.
Err
()
==
context
.
DeadlineExceeded
&&
s
.
rateLimitService
!=
nil
&&
account
!=
nil
{
s
.
rateLimitService
.
HandleStreamTimeout
(
ctx
,
account
,
originalModel
)
}
if
clientStream
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
stream_read_error
\"
}
\n\n
"
)
if
flusher
!=
nil
{
flusher
.
Flush
()
}
}
return
nil
,
err
}
content
:=
contentBuilder
.
String
()
mediaType
,
mediaURLs
:=
s
.
extractSoraMedia
(
content
)
if
mediaType
==
""
&&
isSoraPromptEnhanceModel
(
originalModel
)
{
mediaType
=
"prompt"
}
imageSize
:=
""
imageCount
:=
0
if
mediaType
==
"image"
{
imageSize
=
soraImageSizeFromModel
(
originalModel
)
imageCount
=
len
(
mediaURLs
)
}
if
upstreamError
!=
nil
&&
!
clientStream
{
if
c
!=
nil
{
c
.
JSON
(
http
.
StatusBadGateway
,
map
[
string
]
any
{
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
upstreamError
.
Error
(),
},
})
}
return
nil
,
upstreamError
}
if
!
clientStream
{
response
:=
buildSoraNonStreamResponse
(
content
,
originalModel
)
if
len
(
mediaURLs
)
>
0
{
response
[
"media_url"
]
=
mediaURLs
[
0
]
if
len
(
mediaURLs
)
>
1
{
response
[
"media_urls"
]
=
mediaURLs
}
}
c
.
JSON
(
http
.
StatusOK
,
response
)
}
return
&
soraStreamingResult
{
mediaType
:
mediaType
,
mediaURLs
:
mediaURLs
,
imageCount
:
imageCount
,
imageSize
:
imageSize
,
firstTokenMs
:
firstTokenMs
,
},
nil
}
func
(
s
*
SoraGatewayService
)
processSoraSSEData
(
data
string
,
originalModel
string
,
rewriteBuffer
*
string
)
(
string
,
string
,
error
)
{
if
strings
.
TrimSpace
(
data
)
==
""
{
return
"data: "
,
""
,
nil
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
payload
);
err
!=
nil
{
return
"data: "
+
data
,
""
,
nil
}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
return
"data: "
+
data
,
""
,
errors
.
New
(
msg
)
}
}
if
model
,
ok
:=
payload
[
"model"
]
.
(
string
);
ok
&&
model
!=
""
&&
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
contentDelta
,
updated
:=
extractSoraContent
(
payload
)
if
updated
{
var
rewritten
string
if
rewriteBuffer
!=
nil
{
rewritten
=
s
.
rewriteSoraContentWithBuffer
(
contentDelta
,
rewriteBuffer
)
}
else
{
rewritten
=
s
.
rewriteSoraContent
(
contentDelta
)
}
if
rewritten
!=
contentDelta
{
applySoraContent
(
payload
,
rewritten
)
contentDelta
=
rewritten
}
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
"data: "
+
data
,
contentDelta
,
nil
}
return
"data: "
+
string
(
updatedData
),
contentDelta
,
nil
}
func
extractSoraContent
(
payload
map
[
string
]
any
)
(
string
,
bool
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
""
,
false
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
""
,
false
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
delta
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
message
[
"content"
]
.
(
string
);
ok
{
return
content
,
true
}
}
return
""
,
false
}
func
applySoraContent
(
payload
map
[
string
]
any
,
content
string
)
{
choices
,
ok
:=
payload
[
"choices"
]
.
([]
any
)
if
!
ok
||
len
(
choices
)
==
0
{
return
}
choice
,
ok
:=
choices
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
if
delta
,
ok
:=
choice
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
delta
[
"content"
]
=
content
choice
[
"delta"
]
=
delta
return
}
if
message
,
ok
:=
choice
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
message
[
"content"
]
=
content
choice
[
"message"
]
=
message
}
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContentWithBuffer
(
contentDelta
string
,
buffer
*
string
)
string
{
if
buffer
==
nil
{
return
s
.
rewriteSoraContent
(
contentDelta
)
}
if
contentDelta
==
""
&&
*
buffer
==
""
{
return
""
}
combined
:=
*
buffer
+
contentDelta
rewritten
:=
s
.
rewriteSoraContent
(
combined
)
bufferStart
:=
s
.
findSoraRewriteBufferStart
(
rewritten
)
if
bufferStart
<
0
{
*
buffer
=
""
return
rewritten
}
if
len
(
rewritten
)
-
bufferStart
>
soraRewriteBufferLimit
{
bufferStart
=
len
(
rewritten
)
-
soraRewriteBufferLimit
}
output
:=
rewritten
[
:
bufferStart
]
*
buffer
=
rewritten
[
bufferStart
:
]
return
output
}
func
(
s
*
SoraGatewayService
)
findSoraRewriteBufferStart
(
content
string
)
int
{
minIndex
:=
-
1
start
:=
0
for
{
idx
:=
strings
.
Index
(
content
[
start
:
],
"!["
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraImageMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
2
}
lower
:=
strings
.
ToLower
(
content
)
start
=
0
for
{
idx
:=
strings
.
Index
(
lower
[
start
:
],
"<video"
)
if
idx
<
0
{
break
}
idx
+=
start
if
!
hasSoraVideoMatchAt
(
content
,
idx
)
{
if
minIndex
==
-
1
||
idx
<
minIndex
{
minIndex
=
idx
}
}
start
=
idx
+
len
(
"<video"
)
}
return
minIndex
}
func
hasSoraImageMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraImageMarkdownRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
hasSoraVideoMatchAt
(
content
string
,
idx
int
)
bool
{
if
idx
<
0
||
idx
>=
len
(
content
)
{
return
false
}
loc
:=
soraVideoHTMLRe
.
FindStringIndex
(
content
[
idx
:
])
return
loc
!=
nil
&&
loc
[
0
]
==
0
}
func
(
s
*
SoraGatewayService
)
rewriteSoraContent
(
content
string
)
string
{
if
content
==
""
{
return
content
}
content
=
soraImageMarkdownRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraImageMarkdownRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
content
=
soraVideoHTMLRe
.
ReplaceAllStringFunc
(
content
,
func
(
match
string
)
string
{
sub
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
match
)
if
len
(
sub
)
<
2
{
return
match
}
rewritten
:=
s
.
rewriteSoraURL
(
sub
[
1
])
if
rewritten
==
sub
[
1
]
{
return
match
}
return
strings
.
Replace
(
match
,
sub
[
1
],
rewritten
,
1
)
})
return
content
}
func
(
s
*
SoraGatewayService
)
flushSoraRewriteBuffer
(
buffer
string
,
originalModel
string
)
(
string
,
string
,
error
)
{
if
buffer
==
""
{
return
""
,
""
,
nil
}
rewritten
:=
s
.
rewriteSoraContent
(
buffer
)
payload
:=
map
[
string
]
any
{
"choices"
:
[]
any
{
map
[
string
]
any
{
"delta"
:
map
[
string
]
any
{
"content"
:
rewritten
,
},
"index"
:
0
,
},
},
}
if
originalModel
!=
""
{
payload
[
"model"
]
=
originalModel
}
updatedData
,
err
:=
jsonMarshalRaw
(
payload
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
"data: "
+
string
(
updatedData
),
rewritten
,
nil
}
func
(
s
*
SoraGatewayService
)
rewriteSoraURL
(
raw
string
)
string
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
raw
}
path
:=
parsed
.
Path
if
!
strings
.
HasPrefix
(
path
,
"/tmp/"
)
&&
!
strings
.
HasPrefix
(
path
,
"/static/"
)
{
return
raw
}
return
s
.
buildSoraMediaURL
(
path
,
parsed
.
RawQuery
)
}
func
(
s
*
SoraGatewayService
)
extractSoraMedia
(
content
string
)
(
string
,
[]
string
)
{
if
content
==
""
{
return
""
,
nil
}
if
match
:=
soraVideoHTMLRe
.
FindStringSubmatch
(
content
);
len
(
match
)
>
1
{
return
"video"
,
[]
string
{
match
[
1
]}
}
imageMatches
:=
soraImageMarkdownRe
.
FindAllStringSubmatch
(
content
,
-
1
)
if
len
(
imageMatches
)
==
0
{
return
""
,
nil
}
urls
:=
make
([]
string
,
0
,
len
(
imageMatches
))
for
_
,
match
:=
range
imageMatches
{
if
len
(
match
)
>
1
{
urls
=
append
(
urls
,
match
[
1
])
}
}
return
"image"
,
urls
}
func
isSoraPromptEnhanceModel
(
model
string
)
bool
{
return
strings
.
HasPrefix
(
strings
.
ToLower
(
strings
.
TrimSpace
(
model
)),
"prompt-enhance"
)
}
backend/internal/service/sora_generation.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type
SoraGeneration
struct
{
ID
int64
`json:"id"`
UserID
int64
`json:"user_id"`
APIKeyID
*
int64
`json:"api_key_id,omitempty"`
Model
string
`json:"model"`
Prompt
string
`json:"prompt"`
MediaType
string
`json:"media_type"`
// video / image
Status
string
`json:"status"`
// pending / generating / completed / failed / cancelled
MediaURL
string
`json:"media_url"`
// 主媒体 URL(预签名或 CDN)
MediaURLs
[]
string
`json:"media_urls"`
// 多图时的 URL 数组
FileSizeBytes
int64
`json:"file_size_bytes"`
StorageType
string
`json:"storage_type"`
// s3 / local / upstream / none
S3ObjectKeys
[]
string
`json:"s3_object_keys"`
// S3 object key 数组
UpstreamTaskID
string
`json:"upstream_task_id"`
ErrorMessage
string
`json:"error_message"`
CreatedAt
time
.
Time
`json:"created_at"`
CompletedAt
*
time
.
Time
`json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const
(
SoraGenStatusPending
=
"pending"
SoraGenStatusGenerating
=
"generating"
SoraGenStatusCompleted
=
"completed"
SoraGenStatusFailed
=
"failed"
SoraGenStatusCancelled
=
"cancelled"
)
// Sora 存储类型常量
const
(
SoraStorageTypeS3
=
"s3"
SoraStorageTypeLocal
=
"local"
SoraStorageTypeUpstream
=
"upstream"
SoraStorageTypeNone
=
"none"
)
// SoraGenerationListParams 查询生成记录的参数。
type
SoraGenerationListParams
struct
{
UserID
int64
Status
string
// 可选筛选
StorageType
string
// 可选筛选
MediaType
string
// 可选筛选
Page
int
PageSize
int
}
// SoraGenerationRepository 生成记录持久化接口。
type
SoraGenerationRepository
interface
{
Create
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
Update
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
CountByUserAndStatus
(
ctx
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
}
backend/internal/service/sora_generation_service.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var
(
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit
=
errors
.
New
(
"sora generation concurrent limit exceeded"
)
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict
=
errors
.
New
(
"sora generation state conflict"
)
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive
=
errors
.
New
(
"sora generation is not active"
)
)
const
soraGenerationActiveLimit
=
3
type
soraGenerationRepoAtomicCreator
interface
{
CreatePendingWithLimit
(
ctx
context
.
Context
,
gen
*
SoraGeneration
,
activeStatuses
[]
string
,
maxActive
int64
)
error
}
type
soraGenerationRepoConditionalUpdater
interface
{
UpdateGeneratingIfPending
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
(
bool
,
error
)
UpdateCompletedIfActive
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateFailedIfActive
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateCancelledIfActive
(
ctx
context
.
Context
,
id
int64
,
completedAt
time
.
Time
)
(
bool
,
error
)
UpdateStorageIfCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
(
bool
,
error
)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type
SoraGenerationService
struct
{
genRepo
SoraGenerationRepository
s3Storage
*
SoraS3Storage
quotaService
*
SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func
NewSoraGenerationService
(
genRepo
SoraGenerationRepository
,
s3Storage
*
SoraS3Storage
,
quotaService
*
SoraQuotaService
,
)
*
SoraGenerationService
{
return
&
SoraGenerationService
{
genRepo
:
genRepo
,
s3Storage
:
s3Storage
,
quotaService
:
quotaService
,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func
(
s
*
SoraGenerationService
)
CreatePending
(
ctx
context
.
Context
,
userID
int64
,
apiKeyID
*
int64
,
model
,
prompt
,
mediaType
string
)
(
*
SoraGeneration
,
error
)
{
gen
:=
&
SoraGeneration
{
UserID
:
userID
,
APIKeyID
:
apiKeyID
,
Model
:
model
,
Prompt
:
prompt
,
MediaType
:
mediaType
,
Status
:
SoraGenStatusPending
,
StorageType
:
SoraStorageTypeNone
,
}
if
atomicCreator
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoAtomicCreator
);
ok
{
if
err
:=
atomicCreator
.
CreatePendingWithLimit
(
ctx
,
gen
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
},
soraGenerationActiveLimit
,
);
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSoraGenerationConcurrencyLimit
)
{
return
nil
,
err
}
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
if
err
:=
s
.
genRepo
.
Create
(
ctx
,
gen
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create generation: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 创建记录 id=%d user=%d model=%s"
,
gen
.
ID
,
userID
,
model
)
return
gen
,
nil
}
// MarkGenerating 标记为生成中。
func
(
s
*
SoraGenerationService
)
MarkGenerating
(
ctx
context
.
Context
,
id
int64
,
upstreamTaskID
string
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateGeneratingIfPending
(
ctx
,
id
,
upstreamTaskID
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusGenerating
gen
.
UpstreamTaskID
=
upstreamTaskID
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCompleted 标记为已完成。
func
(
s
*
SoraGenerationService
)
MarkCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCompletedIfActive
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusCompleted
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkFailed 标记为失败。
func
(
s
*
SoraGenerationService
)
MarkFailed
(
ctx
context
.
Context
,
id
int64
,
errMsg
string
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateFailedIfActive
(
ctx
,
id
,
errMsg
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationStateConflict
}
gen
.
Status
=
SoraGenStatusFailed
gen
.
ErrorMessage
=
errMsg
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// MarkCancelled 标记为已取消。
func
(
s
*
SoraGenerationService
)
MarkCancelled
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateCancelledIfActive
(
ctx
,
id
,
now
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationNotActive
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusPending
&&
gen
.
Status
!=
SoraGenStatusGenerating
{
return
ErrSoraGenerationNotActive
}
gen
.
Status
=
SoraGenStatusCancelled
gen
.
CompletedAt
=
&
now
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func
(
s
*
SoraGenerationService
)
UpdateStorageForCompleted
(
ctx
context
.
Context
,
id
int64
,
mediaURL
string
,
mediaURLs
[]
string
,
storageType
string
,
s3Keys
[]
string
,
fileSizeBytes
int64
,
)
error
{
if
updater
,
ok
:=
s
.
genRepo
.
(
soraGenerationRepoConditionalUpdater
);
ok
{
updated
,
err
:=
updater
.
UpdateStorageIfCompleted
(
ctx
,
id
,
mediaURL
,
mediaURLs
,
storageType
,
s3Keys
,
fileSizeBytes
)
if
err
!=
nil
{
return
err
}
if
!
updated
{
return
ErrSoraGenerationStateConflict
}
return
nil
}
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
Status
!=
SoraGenStatusCompleted
{
return
ErrSoraGenerationStateConflict
}
gen
.
MediaURL
=
mediaURL
gen
.
MediaURLs
=
mediaURLs
gen
.
StorageType
=
storageType
gen
.
S3ObjectKeys
=
s3Keys
gen
.
FileSizeBytes
=
fileSizeBytes
return
s
.
genRepo
.
Update
(
ctx
,
gen
)
}
// GetByID 获取记录详情(含权限校验)。
func
(
s
*
SoraGenerationService
)
GetByID
(
ctx
context
.
Context
,
id
,
userID
int64
)
(
*
SoraGeneration
,
error
)
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
gen
.
UserID
!=
userID
{
return
nil
,
fmt
.
Errorf
(
"无权访问此生成记录"
)
}
return
gen
,
nil
}
// List 查询生成记录列表(分页 + 筛选)。
func
(
s
*
SoraGenerationService
)
List
(
ctx
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
params
.
Page
<=
0
{
params
.
Page
=
1
}
if
params
.
PageSize
<=
0
{
params
.
PageSize
=
20
}
if
params
.
PageSize
>
100
{
params
.
PageSize
=
100
}
return
s
.
genRepo
.
List
(
ctx
,
params
)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func
(
s
*
SoraGenerationService
)
Delete
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
gen
,
err
:=
s
.
genRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
gen
.
UserID
!=
userID
{
return
fmt
.
Errorf
(
"无权删除此生成记录"
)
}
// 清理 S3 文件
if
gen
.
StorageType
==
SoraStorageTypeS3
&&
len
(
gen
.
S3ObjectKeys
)
>
0
&&
s
.
s3Storage
!=
nil
{
if
err
:=
s
.
s3Storage
.
DeleteObjects
(
ctx
,
gen
.
S3ObjectKeys
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] S3 清理失败 id=%d err=%v"
,
id
,
err
)
}
}
// 释放配额(S3/本地均释放)
if
gen
.
FileSizeBytes
>
0
&&
(
gen
.
StorageType
==
SoraStorageTypeS3
||
gen
.
StorageType
==
SoraStorageTypeLocal
)
&&
s
.
quotaService
!=
nil
{
if
err
:=
s
.
quotaService
.
ReleaseUsage
(
ctx
,
userID
,
gen
.
FileSizeBytes
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_gen"
,
"[SoraGen] 配额释放失败 id=%d err=%v"
,
id
,
err
)
}
}
return
s
.
genRepo
.
Delete
(
ctx
,
id
)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func
(
s
*
SoraGenerationService
)
CountActiveByUser
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
s
.
genRepo
.
CountByUserAndStatus
(
ctx
,
userID
,
[]
string
{
SoraGenStatusPending
,
SoraGenStatusGenerating
})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func
(
s
*
SoraGenerationService
)
ResolveMediaURLs
(
ctx
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
gen
==
nil
||
gen
.
StorageType
!=
SoraStorageTypeS3
||
s
.
s3Storage
==
nil
{
return
nil
}
if
len
(
gen
.
S3ObjectKeys
)
==
0
{
return
nil
}
urls
:=
make
([]
string
,
len
(
gen
.
S3ObjectKeys
))
var
wg
sync
.
WaitGroup
var
firstErr
error
var
errMu
sync
.
Mutex
for
idx
,
key
:=
range
gen
.
S3ObjectKeys
{
wg
.
Add
(
1
)
go
func
(
i
int
,
objectKey
string
)
{
defer
wg
.
Done
()
url
,
err
:=
s
.
s3Storage
.
GetAccessURL
(
ctx
,
objectKey
)
if
err
!=
nil
{
errMu
.
Lock
()
if
firstErr
==
nil
{
firstErr
=
err
}
errMu
.
Unlock
()
return
}
urls
[
i
]
=
url
}(
idx
,
key
)
}
wg
.
Wait
()
if
firstErr
!=
nil
{
return
firstErr
}
gen
.
MediaURL
=
urls
[
0
]
gen
.
MediaURLs
=
urls
return
nil
}
backend/internal/service/sora_generation_service_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var
_
SoraGenerationRepository
=
(
*
stubGenRepo
)(
nil
)
type
stubGenRepo
struct
{
gens
map
[
int64
]
*
SoraGeneration
nextID
int64
createErr
error
getErr
error
updateErr
error
deleteErr
error
listErr
error
countErr
error
countValue
int64
}
func
newStubGenRepo
()
*
stubGenRepo
{
return
&
stubGenRepo
{
gens
:
make
(
map
[
int64
]
*
SoraGeneration
),
nextID
:
1
}
}
func
(
r
*
stubGenRepo
)
Create
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
createErr
!=
nil
{
return
r
.
createErr
}
gen
.
ID
=
r
.
nextID
gen
.
CreatedAt
=
time
.
Now
()
r
.
nextID
++
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
SoraGeneration
,
error
)
{
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
if
gen
,
ok
:=
r
.
gens
[
id
];
ok
{
return
gen
,
nil
}
return
nil
,
fmt
.
Errorf
(
"not found"
)
}
func
(
r
*
stubGenRepo
)
Update
(
_
context
.
Context
,
gen
*
SoraGeneration
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
gens
[
gen
.
ID
]
=
gen
return
nil
}
func
(
r
*
stubGenRepo
)
Delete
(
_
context
.
Context
,
id
int64
)
error
{
if
r
.
deleteErr
!=
nil
{
return
r
.
deleteErr
}
delete
(
r
.
gens
,
id
)
return
nil
}
func
(
r
*
stubGenRepo
)
List
(
_
context
.
Context
,
params
SoraGenerationListParams
)
([]
*
SoraGeneration
,
int64
,
error
)
{
if
r
.
listErr
!=
nil
{
return
nil
,
0
,
r
.
listErr
}
var
result
[]
*
SoraGeneration
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
!=
params
.
UserID
{
continue
}
if
params
.
Status
!=
""
&&
gen
.
Status
!=
params
.
Status
{
continue
}
if
params
.
StorageType
!=
""
&&
gen
.
StorageType
!=
params
.
StorageType
{
continue
}
if
params
.
MediaType
!=
""
&&
gen
.
MediaType
!=
params
.
MediaType
{
continue
}
result
=
append
(
result
,
gen
)
}
return
result
,
int64
(
len
(
result
)),
nil
}
func
(
r
*
stubGenRepo
)
CountByUserAndStatus
(
_
context
.
Context
,
userID
int64
,
statuses
[]
string
)
(
int64
,
error
)
{
if
r
.
countErr
!=
nil
{
return
0
,
r
.
countErr
}
if
r
.
countValue
>
0
{
return
r
.
countValue
,
nil
}
var
count
int64
statusSet
:=
make
(
map
[
string
]
struct
{})
for
_
,
s
:=
range
statuses
{
statusSet
[
s
]
=
struct
{}{}
}
for
_
,
gen
:=
range
r
.
gens
{
if
gen
.
UserID
==
userID
{
if
_
,
ok
:=
statusSet
[
gen
.
Status
];
ok
{
count
++
}
}
}
return
count
,
nil
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var
_
UserRepository
=
(
*
stubUserRepoForQuota
)(
nil
)
type
stubUserRepoForQuota
struct
{
users
map
[
int64
]
*
User
updateErr
error
}
func
newStubUserRepoForQuota
()
*
stubUserRepoForQuota
{
return
&
stubUserRepoForQuota
{
users
:
make
(
map
[
int64
]
*
User
)}
}
func
(
r
*
stubUserRepoForQuota
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
if
u
,
ok
:=
r
.
users
[
id
];
ok
{
return
u
,
nil
}
return
nil
,
fmt
.
Errorf
(
"user not found"
)
}
func
(
r
*
stubUserRepoForQuota
)
Update
(
_
context
.
Context
,
user
*
User
)
error
{
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
users
[
user
.
ID
]
=
user
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetByEmail
(
context
.
Context
,
string
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
GetFirstAdmin
(
context
.
Context
)
(
*
User
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubUserRepoForQuota
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubUserRepoForQuota
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
func
newS3StorageWithCDN
(
cdnURL
string
)
*
SoraS3Storage
{
storage
:=
&
SoraS3Storage
{}
storage
.
cfg
=
&
SoraS3Settings
{
Enabled
:
true
,
Bucket
:
"test-bucket"
,
CDNURL
:
cdnURL
,
}
// 需要 non-nil client 使 getClient 命中缓存
storage
.
client
=
s3
.
New
(
s3
.
Options
{})
return
storage
}
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
func
newS3StorageFailingDelete
()
*
SoraS3Storage
{
return
&
SoraS3Storage
{}
// settingService 为 nil → getConfig 返回 error
}
// ==================== CreatePending ====================
func
TestCreatePending_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"一只猫跳舞"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
UserID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
require
.
Equal
(
t
,
"一只猫跳舞"
,
gen
.
Prompt
)
require
.
Equal
(
t
,
"video"
,
gen
.
MediaType
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
require
.
Equal
(
t
,
SoraStorageTypeNone
,
gen
.
StorageType
)
require
.
Nil
(
t
,
gen
.
APIKeyID
)
}
func
TestCreatePending_WithAPIKeyID
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
apiKeyID
:=
int64
(
42
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
&
apiKeyID
,
"gpt-image"
,
"画一朵花"
,
"image"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
gen
.
APIKeyID
)
require
.
Equal
(
t
,
int64
(
42
),
*
gen
.
APIKeyID
)
}
func
TestCreatePending_RepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
createErr
=
fmt
.
Errorf
(
"db write error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"create generation"
)
}
// ==================== MarkGenerating ====================
func
TestMarkGenerating_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
"upstream-task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
1
]
.
Status
)
require
.
Equal
(
t
,
"upstream-task-123"
,
repo
.
gens
[
1
]
.
UpstreamTaskID
)
}
func
TestMarkGenerating_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
999
,
""
)
require
.
Error
(
t
,
err
)
}
func
TestMarkGenerating_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkGenerating
(
context
.
Background
(),
1
,
""
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCompleted ====================
func
TestMarkCompleted_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"https://cdn.example.com/video.mp4"
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
SoraStorageTypeS3
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
1048576
,
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
gen
.
Status
)
require
.
Equal
(
t
,
"https://cdn.example.com/video.mp4"
,
gen
.
MediaURL
)
require
.
Equal
(
t
,
[]
string
{
"https://cdn.example.com/video.mp4"
},
gen
.
MediaURLs
)
require
.
Equal
(
t
,
SoraStorageTypeS3
,
gen
.
StorageType
)
require
.
Equal
(
t
,
[]
string
{
"sora/1/2024/01/01/uuid.mp4"
},
gen
.
S3ObjectKeys
)
require
.
Equal
(
t
,
int64
(
1048576
),
gen
.
FileSizeBytes
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkCompleted_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
999
,
""
,
nil
,
""
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCompleted_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCompleted
(
context
.
Background
(),
1
,
"url"
,
nil
,
SoraStorageTypeUpstream
,
nil
,
0
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkFailed ====================
func
TestMarkFailed_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"上游返回 500 错误"
)
require
.
NoError
(
t
,
err
)
gen
:=
repo
.
gens
[
1
]
require
.
Equal
(
t
,
SoraGenStatusFailed
,
gen
.
Status
)
require
.
Equal
(
t
,
"上游返回 500 错误"
,
gen
.
ErrorMessage
)
require
.
NotNil
(
t
,
gen
.
CompletedAt
)
}
func
TestMarkFailed_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
999
,
"error"
)
require
.
Error
(
t
,
err
)
}
func
TestMarkFailed_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
1
,
"err"
)
require
.
Error
(
t
,
err
)
}
// ==================== MarkCancelled ====================
func
TestMarkCancelled_Pending
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
require
.
NotNil
(
t
,
repo
.
gens
[
1
]
.
CompletedAt
)
}
func
TestMarkCancelled_Generating
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
1
]
.
Status
)
}
func
TestMarkCancelled_Completed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
ErrSoraGenerationNotActive
)
}
func
TestMarkCancelled_Failed
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusFailed
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_AlreadyCancelled
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCancelled
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
999
)
require
.
Error
(
t
,
err
)
}
func
TestMarkCancelled_UpdateError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
updateErr
=
fmt
.
Errorf
(
"update failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== GetByID ====================
func
TestGetByID_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
Model
:
"sora2-landscape-10s"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
gen
.
ID
)
require
.
Equal
(
t
,
"sora2-landscape-10s"
,
gen
.
Model
)
}
func
TestGetByID_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestGetByID_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gen
)
}
// ==================== List ====================
func
TestList_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusPending
,
MediaType
:
"image"
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
,
MediaType
:
"video"
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gens
,
total
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
20
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
gens
,
2
)
// 只有 userID=1 的
require
.
Equal
(
t
,
int64
(
2
),
total
)
}
func
TestList_DefaultPagination
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
NoError
(
t
,
err
)
}
func
TestList_MaxPageSize
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// pageSize > 100 → 应限制为 100
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
,
Page
:
1
,
PageSize
:
200
})
require
.
NoError
(
t
,
err
)
}
func
TestList_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
listErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
_
,
err
:=
svc
.
List
(
context
.
Background
(),
SoraGenerationListParams
{
UserID
:
1
})
require
.
Error
(
t
,
err
)
}
// ==================== Delete ====================
func
TestDelete_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
,
StorageType
:
SoraStorageTypeUpstream
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_WrongUser
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
2
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
func
TestDelete_NotFound
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
999
,
1
)
require
.
Error
(
t
,
err
)
}
func
TestDelete_S3Cleanup_NilS3
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// s3Storage 为 nil,跳过清理
}
func
TestDelete_QuotaRelease_NilQuota
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// quotaService 为 nil,跳过释放
}
func
TestDelete_NonS3NoCleanup
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
}
func
TestDelete_DeleteRepoError
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeUpstream
}
repo
.
deleteErr
=
fmt
.
Errorf
(
"delete failed"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
}
// ==================== CountActiveByUser ====================
func
TestCountActiveByUser_Success
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusPending
}
repo
.
gens
[
2
]
=
&
SoraGeneration
{
ID
:
2
,
UserID
:
1
,
Status
:
SoraGenStatusGenerating
}
repo
.
gens
[
3
]
=
&
SoraGeneration
{
ID
:
3
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
// 不算
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
count
)
}
func
TestCountActiveByUser_NoActive
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
Status
:
SoraGenStatusCompleted
}
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
count
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
count
)
}
func
TestCountActiveByUser_Error
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
repo
.
countErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
_
,
err
:=
svc
.
CountActiveByUser
(
context
.
Background
(),
1
)
require
.
Error
(
t
,
err
)
}
// ==================== ResolveMediaURLs ====================
func
TestResolveMediaURLs_NilGen
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
nil
))
}
func
TestResolveMediaURLs_NonS3
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeUpstream
,
MediaURL
:
"https://original.com/v.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"https://original.com/v.mp4"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3NilStorage
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
}}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
}
func
TestResolveMediaURLs_Local
(
t
*
testing
.
T
)
{
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
nil
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeLocal
,
MediaURL
:
"/video/2024/01/01/file.mp4"
}
require
.
NoError
(
t
,
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
))
require
.
Equal
(
t
,
"/video/2024/01/01/file.mp4"
,
gen
.
MediaURL
)
// 不变
}
// ==================== 状态流转完整测试 ====================
func
TestStatusTransition_PendingToCompletedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
// 1. 创建 pending
gen
,
err
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusPending
,
gen
.
Status
)
// 2. 标记 generating
err
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
"task-123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusGenerating
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
// 3. 标记 completed
err
=
svc
.
MarkCompleted
(
context
.
Background
(),
gen
.
ID
,
"https://s3.com/video.mp4"
,
nil
,
SoraStorageTypeS3
,
[]
string
{
"key"
},
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCompleted
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_PendingToFailedFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkFailed
(
context
.
Background
(),
gen
.
ID
,
"上游超时"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusFailed
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
require
.
Equal
(
t
,
"上游超时"
,
repo
.
gens
[
gen
.
ID
]
.
ErrorMessage
)
}
func
TestStatusTransition_PendingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
func
TestStatusTransition_GeneratingToCancelledFlow
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
_
=
svc
.
MarkGenerating
(
context
.
Background
(),
gen
.
ID
,
""
)
err
:=
svc
.
MarkCancelled
(
context
.
Background
(),
gen
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
SoraGenStatusCancelled
,
repo
.
gens
[
gen
.
ID
]
.
Status
)
}
// ==================== 权限隔离测试 ====================
func
TestUserIsolation_CannotAccessOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
// 用户 2 尝试访问用户 1 的记录
_
,
err
:=
svc
.
GetByID
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权访问"
)
}
func
TestUserIsolation_CannotDeleteOthersRecord
(
t
*
testing
.
T
)
{
repo
:=
newStubGenRepo
()
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
nil
)
gen
,
_
:=
svc
.
CreatePending
(
context
.
Background
(),
1
,
nil
,
"sora2-landscape-10s"
,
"test"
,
"video"
)
err
:=
svc
.
Delete
(
context
.
Background
(),
gen
.
ID
,
2
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"无权删除"
)
}
// ==================== Delete: S3 清理 + 配额释放路径 ====================
func
TestDelete_S3Cleanup_WithS3Storage
(
t
*
testing
.
T
)
{
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
// 验证 Delete 仍然成功(S3 错误只是记录日志)
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/abc.mp4"
},
}
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
nil
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// S3 清理失败不影响删除
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
}
func
TestDelete_QuotaRelease_WithQuotaService
(
t
*
testing
.
T
)
{
// 有配额服务时,删除 S3 类型记录会释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
1048576
,
// 1MB
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2097152
}
// 2MB
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
// 配额应被释放: 2MB - 1MB = 1MB
require
.
Equal
(
t
,
int64
(
1048576
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_S3Cleanup_And_QuotaRelease
(
t
*
testing
.
T
)
{
// S3 清理 + 配额释放同时触发
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"key1"
},
FileSizeBytes
:
512
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
repo
,
s3Storage
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
_
,
exists
:=
repo
.
gens
[
1
]
require
.
False
(
t
,
exists
)
require
.
Equal
(
t
,
int64
(
512
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_LocalStorage
(
t
*
testing
.
T
)
{
// 本地存储同样需要释放配额
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeLocal
,
FileSizeBytes
:
1024
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
2048
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestDelete_QuotaRelease_ZeroFileSize
(
t
*
testing
.
T
)
{
// FileSizeBytes=0 跳过配额释放
repo
:=
newStubGenRepo
()
repo
.
gens
[
1
]
=
&
SoraGeneration
{
ID
:
1
,
UserID
:
1
,
StorageType
:
SoraStorageTypeS3
,
FileSizeBytes
:
0
,
}
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
quotaService
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
svc
:=
NewSoraGenerationService
(
repo
,
nil
,
quotaService
)
err
:=
svc
.
Delete
(
context
.
Background
(),
1
,
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func
TestResolveMediaURLs_S3_CDN_SingleKey
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/video.mp4"
,
gen
.
MediaURL
)
}
func
TestResolveMediaURLs_S3_CDN_MultipleKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com/"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
"sora/1/2024/01/01/img3.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
// 主 URL 更新为第一个 key 的 CDN URL
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURL
)
// 多图 URLs 全部更新
require
.
Len
(
t
,
gen
.
MediaURLs
,
3
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img1.png"
,
gen
.
MediaURLs
[
0
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img2.png"
,
gen
.
MediaURLs
[
1
])
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/img3.png"
,
gen
.
MediaURLs
[
2
])
}
func
TestResolveMediaURLs_S3_EmptyKeys
(
t
*
testing
.
T
)
{
s3Storage
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"original"
,
gen
.
MediaURL
)
// 不变
}
func
TestResolveMediaURLs_S3_GetAccessURL_Error
(
t
*
testing
.
T
)
{
// 使用无 settingService 的 S3 Storage,getClient 会失败
s3Storage
:=
newS3StorageFailingDelete
()
// 同样 GetAccessURL 也会失败
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/video.mp4"
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// GetAccessURL 失败应传播错误
}
func
TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond
(
t
*
testing
.
T
)
{
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage
:=
newS3StorageFailingDelete
()
svc
:=
NewSoraGenerationService
(
newStubGenRepo
(),
s3Storage
,
nil
)
gen
:=
&
SoraGeneration
{
StorageType
:
SoraStorageTypeS3
,
S3ObjectKeys
:
[]
string
{
"sora/1/2024/01/01/img1.png"
,
"sora/1/2024/01/01/img2.png"
,
},
MediaURL
:
"original"
,
}
err
:=
svc
.
ResolveMediaURLs
(
context
.
Background
(),
gen
)
require
.
Error
(
t
,
err
)
// 第一个 key 的 GetAccessURL 就会失败
}
backend/internal/service/sora_media_cleanup_service.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
var
soraCleanupCronParser
=
cron
.
NewParser
(
cron
.
Minute
|
cron
.
Hour
|
cron
.
Dom
|
cron
.
Month
|
cron
.
Dow
)
// SoraMediaCleanupService 定期清理本地媒体文件
type
SoraMediaCleanupService
struct
{
storage
*
SoraMediaStorage
cfg
*
config
.
Config
cron
*
cron
.
Cron
startOnce
sync
.
Once
stopOnce
sync
.
Once
}
func
NewSoraMediaCleanupService
(
storage
*
SoraMediaStorage
,
cfg
*
config
.
Config
)
*
SoraMediaCleanupService
{
return
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
cfg
,
}
}
func
(
s
*
SoraMediaCleanupService
)
Start
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
if
!
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Enabled
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (disabled)"
)
return
}
if
s
.
storage
==
nil
||
!
s
.
storage
.
Enabled
()
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (storage disabled)"
)
return
}
s
.
startOnce
.
Do
(
func
()
{
schedule
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
Schedule
)
if
schedule
==
""
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (empty schedule)"
)
return
}
loc
:=
time
.
Local
if
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
)
!=
""
{
if
parsed
,
err
:=
time
.
LoadLocation
(
strings
.
TrimSpace
(
s
.
cfg
.
Timezone
));
err
==
nil
&&
parsed
!=
nil
{
loc
=
parsed
}
}
c
:=
cron
.
New
(
cron
.
WithParser
(
soraCleanupCronParser
),
cron
.
WithLocation
(
loc
))
if
_
,
err
:=
c
.
AddFunc
(
schedule
,
func
()
{
s
.
runCleanup
()
});
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] not started (invalid schedule=%q): %v"
,
schedule
,
err
)
return
}
s
.
cron
=
c
s
.
cron
.
Start
()
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] started (schedule=%q tz=%s)"
,
schedule
,
loc
.
String
())
})
}
func
(
s
*
SoraMediaCleanupService
)
Stop
()
{
if
s
==
nil
{
return
}
s
.
stopOnce
.
Do
(
func
()
{
if
s
.
cron
!=
nil
{
ctx
:=
s
.
cron
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
case
<-
time
.
After
(
3
*
time
.
Second
)
:
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cron stop timed out"
)
}
}
})
}
func
(
s
*
SoraMediaCleanupService
)
runCleanup
()
{
if
s
.
cfg
==
nil
||
s
.
storage
==
nil
{
return
}
retention
:=
s
.
cfg
.
Sora
.
Storage
.
Cleanup
.
RetentionDays
if
retention
<=
0
{
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] skipped (retention_days=%d)"
,
retention
)
return
}
cutoff
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
retention
)
deleted
:=
0
roots
:=
[]
string
{
s
.
storage
.
ImageRoot
(),
s
.
storage
.
VideoRoot
()}
for
_
,
root
:=
range
roots
{
if
root
==
""
{
continue
}
_
=
filepath
.
Walk
(
root
,
func
(
p
string
,
info
os
.
FileInfo
,
err
error
)
error
{
if
err
!=
nil
{
return
nil
}
if
info
.
IsDir
()
{
return
nil
}
if
info
.
ModTime
()
.
Before
(
cutoff
)
{
if
rmErr
:=
os
.
Remove
(
p
);
rmErr
==
nil
{
deleted
++
}
}
return
nil
})
}
logger
.
LegacyPrintf
(
"service.sora_media_cleanup"
,
"[SoraCleanup] cleanup finished, deleted=%d"
,
deleted
)
}
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment