Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d757df8a
Unverified
Commit
d757df8a
authored
Apr 05, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 05, 2026
Browse files
Merge pull request #1463 from touwaeriol/feat/remove-sora
revert: completely remove Sora platform
parents
f585a15e
19655a15
Changes
163
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/sora_media_cleanup_service_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaCleanupService_RunCleanup_NilCfg
(
t
*
testing
.
T
)
{
storage
:=
&
SoraMediaStorage
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
storage
,
cfg
:
nil
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_NilStorage
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
svc
:=
&
SoraMediaCleanupService
{
storage
:
nil
,
cfg
:
cfg
}
// 不应 panic
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_RunCleanup_ZeroRetention
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
0
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
// retention=0 应跳过清理
svc
.
runCleanup
()
}
func
TestSoraMediaCleanupService_Start_NilCfg
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
nil
)
svc
.
Start
()
// cfg == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_StorageDisabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// storage == nil 时应直接返回
}
func
TestSoraMediaCleanupService_Start_WithTimezone
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Timezone
:
"Asia/Shanghai"
,
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Start_Disabled
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
false
,
},
},
},
}
svc
:=
NewSoraMediaCleanupService
(
nil
,
cfg
)
svc
.
Start
()
// 不应 panic,也不应启动 cron
}
func
TestSoraMediaCleanupService_Start_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Start
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Start_EmptySchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
""
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 空 schedule 不应启动
}
func
TestSoraMediaCleanupService_Start_InvalidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"invalid-cron"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
// 无效 schedule 不应 panic
}
func
TestSoraMediaCleanupService_Start_ValidSchedule
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
Schedule
:
"0 3 * * *"
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
t
.
Cleanup
(
svc
.
Stop
)
}
func
TestSoraMediaCleanupService_Stop_NilSelf
(
t
*
testing
.
T
)
{
var
svc
*
SoraMediaCleanupService
svc
.
Stop
()
// 不应 panic
}
func
TestSoraMediaCleanupService_Stop_WithoutStart
(
t
*
testing
.
T
)
{
svc
:=
NewSoraMediaCleanupService
(
nil
,
&
config
.
Config
{})
svc
.
Stop
()
// cron 未启动时 Stop 不应 panic
}
func
TestSoraMediaCleanupService_RunCleanup
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
Cleanup
:
config
.
SoraStorageCleanupConfig
{
Enabled
:
true
,
RetentionDays
:
1
,
},
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
require
.
NoError
(
t
,
storage
.
EnsureLocalDirs
())
oldImage
:=
filepath
.
Join
(
storage
.
ImageRoot
(),
"old.png"
)
newVideo
:=
filepath
.
Join
(
storage
.
VideoRoot
(),
"new.mp4"
)
require
.
NoError
(
t
,
os
.
WriteFile
(
oldImage
,
[]
byte
(
"old"
),
0
o644
))
require
.
NoError
(
t
,
os
.
WriteFile
(
newVideo
,
[]
byte
(
"new"
),
0
o644
))
oldTime
:=
time
.
Now
()
.
Add
(
-
48
*
time
.
Hour
)
require
.
NoError
(
t
,
os
.
Chtimes
(
oldImage
,
oldTime
,
oldTime
))
cleanup
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
cleanup
.
runCleanup
()
require
.
NoFileExists
(
t
,
oldImage
)
require
.
FileExists
(
t
,
newVideo
)
}
backend/internal/service/sora_media_sign.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func
SignSoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
key
string
)
string
{
key
=
strings
.
TrimSpace
(
key
)
if
key
==
""
{
return
""
}
mac
:=
hmac
.
New
(
sha256
.
New
,
[]
byte
(
key
))
if
_
,
err
:=
mac
.
Write
([]
byte
(
buildSoraMediaSignPayload
(
path
,
query
)));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
"|"
));
err
!=
nil
{
return
""
}
if
_
,
err
:=
mac
.
Write
([]
byte
(
strconv
.
FormatInt
(
expires
,
10
)));
err
!=
nil
{
return
""
}
return
hex
.
EncodeToString
(
mac
.
Sum
(
nil
))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func
VerifySoraMediaURL
(
path
string
,
query
string
,
expires
int64
,
signature
string
,
key
string
)
bool
{
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
return
false
}
expected
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
expected
==
""
{
return
false
}
return
hmac
.
Equal
([]
byte
(
signature
),
[]
byte
(
expected
))
}
func
buildSoraMediaSignPayload
(
path
string
,
query
string
)
string
{
if
strings
.
TrimSpace
(
query
)
==
""
{
return
path
}
return
path
+
"?"
+
query
}
backend/internal/service/sora_media_sign_test.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
"testing"
func
TestSoraMediaSignVerify
(
t
*
testing
.
T
)
{
key
:=
"test-key"
path
:=
"/tmp/abc.png"
query
:=
"a=1&b=2"
expires
:=
int64
(
1700000000
)
signature
:=
SignSoraMediaURL
(
path
,
query
,
expires
,
key
)
if
signature
==
""
{
t
.
Fatal
(
"签名为空"
)
}
if
!
VerifySoraMediaURL
(
path
,
query
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名校验失败"
)
}
if
VerifySoraMediaURL
(
path
,
"a=1"
,
expires
,
signature
,
key
)
{
t
.
Fatal
(
"签名参数不同仍然通过"
)
}
if
VerifySoraMediaURL
(
path
,
query
,
expires
+
1
,
signature
,
key
)
{
t
.
Fatal
(
"签名过期校验未失败"
)
}
}
func
TestSoraMediaSignWithEmptyKey
(
t
*
testing
.
T
)
{
signature
:=
SignSoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
""
)
if
signature
!=
""
{
t
.
Fatalf
(
"空密钥不应生成签名"
)
}
if
VerifySoraMediaURL
(
"/tmp/a.png"
,
"a=1"
,
1
,
"sig"
,
""
)
{
t
.
Fatalf
(
"空密钥不应通过校验"
)
}
}
backend/internal/service/sora_media_storage.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const
(
soraStorageDefaultRoot
=
"/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type
SoraMediaStorage
struct
{
cfg
*
config
.
Config
root
string
imageRoot
string
videoRoot
string
downloadTimeout
time
.
Duration
maxDownloadBytes
int64
fallbackToUpstream
bool
debug
bool
sem
chan
struct
{}
ready
bool
}
func
NewSoraMediaStorage
(
cfg
*
config
.
Config
)
*
SoraMediaStorage
{
storage
:=
&
SoraMediaStorage
{
cfg
:
cfg
}
storage
.
refreshConfig
()
if
storage
.
Enabled
()
{
if
err
:=
storage
.
EnsureLocalDirs
();
err
!=
nil
{
log
.
Printf
(
"[SoraStorage] 初始化失败: %v"
,
err
)
}
}
return
storage
}
func
(
s
*
SoraMediaStorage
)
Enabled
()
bool
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
false
}
return
strings
.
ToLower
(
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
Type
))
==
"local"
}
func
(
s
*
SoraMediaStorage
)
Root
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
root
}
func
(
s
*
SoraMediaStorage
)
ImageRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
imageRoot
}
func
(
s
*
SoraMediaStorage
)
VideoRoot
()
string
{
if
s
==
nil
{
return
""
}
return
s
.
videoRoot
}
func
(
s
*
SoraMediaStorage
)
refreshConfig
()
{
if
s
==
nil
||
s
.
cfg
==
nil
{
return
}
root
:=
strings
.
TrimSpace
(
s
.
cfg
.
Sora
.
Storage
.
LocalPath
)
if
root
==
""
{
root
=
soraStorageDefaultRoot
}
root
=
filepath
.
Clean
(
root
)
if
!
filepath
.
IsAbs
(
root
)
{
if
absRoot
,
err
:=
filepath
.
Abs
(
root
);
err
==
nil
{
root
=
absRoot
}
}
s
.
root
=
root
s
.
imageRoot
=
filepath
.
Join
(
root
,
"image"
)
s
.
videoRoot
=
filepath
.
Join
(
root
,
"video"
)
maxConcurrent
:=
s
.
cfg
.
Sora
.
Storage
.
MaxConcurrentDownloads
if
maxConcurrent
<=
0
{
maxConcurrent
=
4
}
timeoutSeconds
:=
s
.
cfg
.
Sora
.
Storage
.
DownloadTimeoutSeconds
if
timeoutSeconds
<=
0
{
timeoutSeconds
=
120
}
s
.
downloadTimeout
=
time
.
Duration
(
timeoutSeconds
)
*
time
.
Second
maxBytes
:=
s
.
cfg
.
Sora
.
Storage
.
MaxDownloadBytes
if
maxBytes
<=
0
{
maxBytes
=
200
<<
20
}
s
.
maxDownloadBytes
=
maxBytes
s
.
fallbackToUpstream
=
s
.
cfg
.
Sora
.
Storage
.
FallbackToUpstream
s
.
debug
=
s
.
cfg
.
Sora
.
Storage
.
Debug
s
.
sem
=
make
(
chan
struct
{},
maxConcurrent
)
}
// EnsureLocalDirs 创建并校验本地目录
func
(
s
*
SoraMediaStorage
)
EnsureLocalDirs
()
error
{
if
s
==
nil
||
!
s
.
Enabled
()
{
return
nil
}
if
err
:=
os
.
MkdirAll
(
s
.
imageRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create image dir: %w"
,
err
)
}
if
err
:=
os
.
MkdirAll
(
s
.
videoRoot
,
0
o755
);
err
!=
nil
{
return
fmt
.
Errorf
(
"create video dir: %w"
,
err
)
}
s
.
ready
=
true
return
nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func
(
s
*
SoraMediaStorage
)
StoreFromURLs
(
ctx
context
.
Context
,
mediaType
string
,
urls
[]
string
)
([]
string
,
error
)
{
if
len
(
urls
)
==
0
{
return
nil
,
nil
}
if
s
==
nil
||
!
s
.
Enabled
()
{
return
urls
,
nil
}
if
!
s
.
ready
{
if
err
:=
s
.
EnsureLocalDirs
();
err
!=
nil
{
return
nil
,
err
}
}
results
:=
make
([]
string
,
0
,
len
(
urls
))
for
_
,
raw
:=
range
urls
{
relative
,
err
:=
s
.
downloadAndStore
(
ctx
,
mediaType
,
raw
)
if
err
!=
nil
{
if
s
.
fallbackToUpstream
{
results
=
append
(
results
,
raw
)
continue
}
return
nil
,
err
}
results
=
append
(
results
,
relative
)
}
return
results
,
nil
}
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
TotalSizeByRelativePaths
(
paths
[]
string
)
(
int64
,
error
)
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
0
,
nil
}
var
total
int64
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
info
,
err
:=
os
.
Stat
(
localPath
)
if
err
!=
nil
{
if
os
.
IsNotExist
(
err
)
{
continue
}
return
0
,
err
}
if
info
.
Mode
()
.
IsRegular
()
{
total
+=
info
.
Size
()
}
}
return
total
,
nil
}
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
func
(
s
*
SoraMediaStorage
)
DeleteByRelativePaths
(
paths
[]
string
)
error
{
if
s
==
nil
||
len
(
paths
)
==
0
{
return
nil
}
var
lastErr
error
for
_
,
p
:=
range
paths
{
localPath
,
err
:=
s
.
resolveLocalPath
(
p
)
if
err
!=
nil
{
continue
}
if
err
:=
os
.
Remove
(
localPath
);
err
!=
nil
&&
!
os
.
IsNotExist
(
err
)
{
lastErr
=
err
}
}
return
lastErr
}
func
(
s
*
SoraMediaStorage
)
resolveLocalPath
(
relativePath
string
)
(
string
,
error
)
{
if
s
==
nil
||
strings
.
TrimSpace
(
relativePath
)
==
""
{
return
""
,
errors
.
New
(
"empty path"
)
}
cleaned
:=
path
.
Clean
(
relativePath
)
if
!
strings
.
HasPrefix
(
cleaned
,
"/image/"
)
&&
!
strings
.
HasPrefix
(
cleaned
,
"/video/"
)
{
return
""
,
errors
.
New
(
"not a local media path"
)
}
if
strings
.
TrimSpace
(
s
.
root
)
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
relative
:=
strings
.
TrimPrefix
(
cleaned
,
"/"
)
return
filepath
.
Join
(
s
.
root
,
filepath
.
FromSlash
(
relative
)),
nil
}
func
(
s
*
SoraMediaStorage
)
downloadAndStore
(
ctx
context
.
Context
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
if
strings
.
TrimSpace
(
rawURL
)
==
""
{
return
""
,
errors
.
New
(
"empty url"
)
}
root
:=
s
.
imageRoot
if
mediaType
==
"video"
{
root
=
s
.
videoRoot
}
if
root
==
""
{
return
""
,
errors
.
New
(
"storage root not configured"
)
}
retries
:=
3
for
attempt
:=
1
;
attempt
<=
retries
;
attempt
++
{
release
,
err
:=
s
.
acquire
(
ctx
)
if
err
!=
nil
{
return
""
,
err
}
relative
,
err
:=
s
.
downloadOnce
(
ctx
,
root
,
mediaType
,
rawURL
)
release
()
if
err
==
nil
{
return
relative
,
nil
}
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 下载失败(%d/%d): %s err=%v"
,
attempt
,
retries
,
sanitizeMediaLogURL
(
rawURL
),
err
)
}
if
attempt
<
retries
{
time
.
Sleep
(
time
.
Duration
(
attempt
*
attempt
)
*
time
.
Second
)
continue
}
return
""
,
err
}
return
""
,
errors
.
New
(
"download retries exhausted"
)
}
func
(
s
*
SoraMediaStorage
)
downloadOnce
(
ctx
context
.
Context
,
root
,
mediaType
,
rawURL
string
)
(
string
,
error
)
{
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
rawURL
,
nil
)
if
err
!=
nil
{
return
""
,
err
}
client
:=
&
http
.
Client
{
Timeout
:
s
.
downloadTimeout
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
return
""
,
fmt
.
Errorf
(
"download failed: %d %s"
,
resp
.
StatusCode
,
string
(
body
))
}
ext
:=
normalizeSoraFileExt
(
fileExtFromURL
(
rawURL
))
if
ext
==
""
{
ext
=
normalizeSoraFileExt
(
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
)))
}
if
ext
==
""
{
ext
=
".bin"
}
if
s
.
maxDownloadBytes
>
0
&&
resp
.
ContentLength
>
s
.
maxDownloadBytes
{
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
resp
.
ContentLength
)
}
storageRoot
,
err
:=
os
.
OpenRoot
(
root
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
storageRoot
.
Close
()
}()
datePath
:=
time
.
Now
()
.
Format
(
"2006/01/02"
)
datePathFS
:=
filepath
.
FromSlash
(
datePath
)
if
err
:=
storageRoot
.
MkdirAll
(
datePathFS
,
0
o755
);
err
!=
nil
{
return
""
,
err
}
filename
:=
uuid
.
NewString
()
+
ext
filePath
:=
filepath
.
Join
(
datePathFS
,
filename
)
out
,
err
:=
storageRoot
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o644
)
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
out
.
Close
()
}()
limited
:=
io
.
LimitReader
(
resp
.
Body
,
s
.
maxDownloadBytes
+
1
)
written
,
err
:=
io
.
Copy
(
out
,
limited
)
if
err
!=
nil
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
err
}
if
s
.
maxDownloadBytes
>
0
&&
written
>
s
.
maxDownloadBytes
{
removePartialDownload
(
storageRoot
,
filePath
)
return
""
,
fmt
.
Errorf
(
"download size exceeds limit: %d"
,
written
)
}
relative
:=
path
.
Join
(
"/"
,
mediaType
,
datePath
,
filename
)
if
s
.
debug
{
log
.
Printf
(
"[SoraStorage] 已落地 %s -> %s"
,
sanitizeMediaLogURL
(
rawURL
),
relative
)
}
return
relative
,
nil
}
func
(
s
*
SoraMediaStorage
)
acquire
(
ctx
context
.
Context
)
(
func
(),
error
)
{
if
s
.
sem
==
nil
{
return
func
()
{},
nil
}
select
{
case
s
.
sem
<-
struct
{}{}
:
return
func
()
{
<-
s
.
sem
},
nil
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
}
}
func
fileExtFromURL
(
raw
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
""
}
ext
:=
path
.
Ext
(
parsed
.
Path
)
return
strings
.
ToLower
(
ext
)
}
func
fileExtFromContentType
(
ct
string
)
string
{
if
ct
==
""
{
return
""
}
if
exts
,
err
:=
mime
.
ExtensionsByType
(
ct
);
err
==
nil
&&
len
(
exts
)
>
0
{
return
strings
.
ToLower
(
exts
[
0
])
}
return
""
}
func
normalizeSoraFileExt
(
ext
string
)
string
{
ext
=
strings
.
ToLower
(
strings
.
TrimSpace
(
ext
))
switch
ext
{
case
".png"
,
".jpg"
,
".jpeg"
,
".gif"
,
".webp"
,
".bmp"
,
".svg"
,
".tif"
,
".tiff"
,
".heic"
,
".mp4"
,
".mov"
,
".webm"
,
".m4v"
,
".avi"
,
".mkv"
,
".3gp"
,
".flv"
:
return
ext
default
:
return
""
}
}
func
removePartialDownload
(
root
*
os
.
Root
,
filePath
string
)
{
if
root
==
nil
||
strings
.
TrimSpace
(
filePath
)
==
""
{
return
}
_
=
root
.
Remove
(
filePath
)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func
sanitizeMediaLogURL
(
rawURL
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
rawURL
)
if
err
!=
nil
{
if
len
(
rawURL
)
>
80
{
return
rawURL
[
:
80
]
+
"..."
}
return
rawURL
}
safe
:=
parsed
.
Scheme
+
"://"
+
parsed
.
Host
+
parsed
.
Path
if
len
(
safe
)
>
120
{
return
safe
[
:
120
]
+
"..."
}
return
safe
}
backend/internal/service/sora_media_storage_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSoraMediaStorage_StoreFromURLs
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"data"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxConcurrentDownloads
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
urls
,
1
)
require
.
True
(
t
,
strings
.
HasPrefix
(
urls
[
0
],
"/image/"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
urls
[
0
],
".png"
))
localPath
:=
filepath
.
Join
(
tmpDir
,
filepath
.
FromSlash
(
strings
.
TrimPrefix
(
urls
[
0
],
"/"
)))
require
.
FileExists
(
t
,
localPath
)
}
func
TestSoraMediaStorage_FallbackToUpstream
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusInternalServerError
)
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
FallbackToUpstream
:
true
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
url
:=
server
.
URL
+
"/broken.png"
urls
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
url
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
url
},
urls
)
}
func
TestSoraMediaStorage_MaxDownloadBytes
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"image/png"
)
w
.
WriteHeader
(
http
.
StatusOK
)
_
,
_
=
w
.
Write
([]
byte
(
"too-large"
))
}))
defer
server
.
Close
()
cfg
:=
&
config
.
Config
{
Sora
:
config
.
SoraConfig
{
Storage
:
config
.
SoraStorageConfig
{
Type
:
"local"
,
LocalPath
:
tmpDir
,
MaxDownloadBytes
:
1
,
},
},
}
storage
:=
NewSoraMediaStorage
(
cfg
)
_
,
err
:=
storage
.
StoreFromURLs
(
context
.
Background
(),
"image"
,
[]
string
{
server
.
URL
+
"/img.png"
})
require
.
Error
(
t
,
err
)
}
func
TestNormalizeSoraFileExt
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
".png"
,
normalizeSoraFileExt
(
".PNG"
))
require
.
Equal
(
t
,
".mp4"
,
normalizeSoraFileExt
(
".mp4"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
"../../etc/passwd"
))
require
.
Equal
(
t
,
""
,
normalizeSoraFileExt
(
".php"
))
}
func
TestRemovePartialDownload
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
root
,
err
:=
os
.
OpenRoot
(
tmpDir
)
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
root
.
Close
()
}()
filePath
:=
"partial.bin"
f
,
err
:=
root
.
OpenFile
(
filePath
,
os
.
O_CREATE
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
0
o600
)
require
.
NoError
(
t
,
err
)
_
,
_
=
f
.
WriteString
(
"partial"
)
_
=
f
.
Close
()
removePartialDownload
(
root
,
filePath
)
_
,
err
=
root
.
Stat
(
filePath
)
require
.
Error
(
t
,
err
)
require
.
True
(
t
,
os
.
IsNotExist
(
err
))
}
backend/internal/service/sora_models.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"regexp"
"sort"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
// SoraModelConfig Sora 模型配置
type
SoraModelConfig
struct
{
Type
string
Width
int
Height
int
Orientation
string
Frames
int
Model
string
Size
string
RequirePro
bool
// Prompt-enhance 专用参数
ExpansionLevel
string
DurationS
int
}
var
soraModelConfigs
=
map
[
string
]
SoraModelConfig
{
"gpt-image"
:
{
Type
:
"image"
,
Width
:
360
,
Height
:
360
,
},
"gpt-image-landscape"
:
{
Type
:
"image"
,
Width
:
540
,
Height
:
360
,
},
"gpt-image-portrait"
:
{
Type
:
"image"
,
Width
:
360
,
Height
:
540
,
},
"sora2-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_8"
,
Size
:
"small"
,
},
"sora2-landscape-25s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
750
,
Model
:
"sy_8"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2-portrait-25s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
750
,
Model
:
"sy_8"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-landscape-25s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
750
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-portrait-25s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
750
,
Model
:
"sy_ore"
,
Size
:
"small"
,
RequirePro
:
true
,
},
"sora2pro-hd-landscape-10s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-portrait-10s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
300
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-landscape-15s"
:
{
Type
:
"video"
,
Orientation
:
"landscape"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"sora2pro-hd-portrait-15s"
:
{
Type
:
"video"
,
Orientation
:
"portrait"
,
Frames
:
450
,
Model
:
"sy_ore"
,
Size
:
"large"
,
RequirePro
:
true
,
},
"prompt-enhance-short-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
10
,
},
"prompt-enhance-short-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
15
,
},
"prompt-enhance-short-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"short"
,
DurationS
:
20
,
},
"prompt-enhance-medium-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
10
,
},
"prompt-enhance-medium-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
15
,
},
"prompt-enhance-medium-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"medium"
,
DurationS
:
20
,
},
"prompt-enhance-long-10s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
10
,
},
"prompt-enhance-long-15s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
15
,
},
"prompt-enhance-long-20s"
:
{
Type
:
"prompt_enhance"
,
ExpansionLevel
:
"long"
,
DurationS
:
20
,
},
}
var
soraModelIDs
=
[]
string
{
"gpt-image"
,
"gpt-image-landscape"
,
"gpt-image-portrait"
,
"sora2-landscape-10s"
,
"sora2-portrait-10s"
,
"sora2-landscape-15s"
,
"sora2-portrait-15s"
,
"sora2-landscape-25s"
,
"sora2-portrait-25s"
,
"sora2pro-landscape-10s"
,
"sora2pro-portrait-10s"
,
"sora2pro-landscape-15s"
,
"sora2pro-portrait-15s"
,
"sora2pro-landscape-25s"
,
"sora2pro-portrait-25s"
,
"sora2pro-hd-landscape-10s"
,
"sora2pro-hd-portrait-10s"
,
"sora2pro-hd-landscape-15s"
,
"sora2pro-hd-portrait-15s"
,
"prompt-enhance-short-10s"
,
"prompt-enhance-short-15s"
,
"prompt-enhance-short-20s"
,
"prompt-enhance-medium-10s"
,
"prompt-enhance-medium-15s"
,
"prompt-enhance-medium-20s"
,
"prompt-enhance-long-10s"
,
"prompt-enhance-long-15s"
,
"prompt-enhance-long-20s"
,
}
// GetSoraModelConfig 返回 Sora 模型配置
func
GetSoraModelConfig
(
model
string
)
(
SoraModelConfig
,
bool
)
{
key
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
cfg
,
ok
:=
soraModelConfigs
[
key
]
return
cfg
,
ok
}
// SoraModelFamily 模型家族(前端 Sora 客户端使用)
type
SoraModelFamily
struct
{
ID
string
`json:"id"`
Name
string
`json:"name"`
Type
string
`json:"type"`
Orientations
[]
string
`json:"orientations"`
Durations
[]
int
`json:"durations,omitempty"`
}
var
(
videoSuffixRe
=
regexp
.
MustCompile
(
`-(landscape|portrait)-(\d+)s$`
)
imageSuffixRe
=
regexp
.
MustCompile
(
`-(landscape|portrait)$`
)
soraFamilyNames
=
map
[
string
]
string
{
"sora2"
:
"Sora 2"
,
"sora2pro"
:
"Sora 2 Pro"
,
"sora2pro-hd"
:
"Sora 2 Pro HD"
,
"gpt-image"
:
"GPT Image"
,
}
)
// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
func
BuildSoraModelFamilies
()
[]
SoraModelFamily
{
type
familyData
struct
{
modelType
string
orientations
map
[
string
]
bool
durations
map
[
int
]
bool
}
families
:=
make
(
map
[
string
]
*
familyData
)
for
id
,
cfg
:=
range
soraModelConfigs
{
if
cfg
.
Type
==
"prompt_enhance"
{
continue
}
var
famID
,
orientation
string
var
duration
int
switch
cfg
.
Type
{
case
"video"
:
if
m
:=
videoSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
duration
,
_
=
strconv
.
Atoi
(
m
[
2
])
}
case
"image"
:
if
m
:=
imageSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
}
else
{
famID
=
id
orientation
=
"square"
}
}
if
famID
==
""
{
continue
}
fd
,
ok
:=
families
[
famID
]
if
!
ok
{
fd
=
&
familyData
{
modelType
:
cfg
.
Type
,
orientations
:
make
(
map
[
string
]
bool
),
durations
:
make
(
map
[
int
]
bool
),
}
families
[
famID
]
=
fd
}
if
orientation
!=
""
{
fd
.
orientations
[
orientation
]
=
true
}
if
duration
>
0
{
fd
.
durations
[
duration
]
=
true
}
}
// 排序:视频在前、图像在后,同类按名称排序
famIDs
:=
make
([]
string
,
0
,
len
(
families
))
for
id
:=
range
families
{
famIDs
=
append
(
famIDs
,
id
)
}
sort
.
Slice
(
famIDs
,
func
(
i
,
j
int
)
bool
{
fi
,
fj
:=
families
[
famIDs
[
i
]],
families
[
famIDs
[
j
]]
if
fi
.
modelType
!=
fj
.
modelType
{
return
fi
.
modelType
==
"video"
}
return
famIDs
[
i
]
<
famIDs
[
j
]
})
result
:=
make
([]
SoraModelFamily
,
0
,
len
(
famIDs
))
for
_
,
famID
:=
range
famIDs
{
fd
:=
families
[
famID
]
fam
:=
SoraModelFamily
{
ID
:
famID
,
Name
:
soraFamilyNames
[
famID
],
Type
:
fd
.
modelType
,
}
if
fam
.
Name
==
""
{
fam
.
Name
=
famID
}
for
o
:=
range
fd
.
orientations
{
fam
.
Orientations
=
append
(
fam
.
Orientations
,
o
)
}
sort
.
Strings
(
fam
.
Orientations
)
for
d
:=
range
fd
.
durations
{
fam
.
Durations
=
append
(
fam
.
Durations
,
d
)
}
sort
.
Ints
(
fam
.
Durations
)
result
=
append
(
result
,
fam
)
}
return
result
}
// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
// 通过命名约定自动识别视频/图像模型并分组。
func
BuildSoraModelFamiliesFromIDs
(
modelIDs
[]
string
)
[]
SoraModelFamily
{
type
familyData
struct
{
modelType
string
orientations
map
[
string
]
bool
durations
map
[
int
]
bool
}
families
:=
make
(
map
[
string
]
*
familyData
)
for
_
,
id
:=
range
modelIDs
{
id
=
strings
.
ToLower
(
strings
.
TrimSpace
(
id
))
if
id
==
""
||
strings
.
HasPrefix
(
id
,
"prompt-enhance"
)
{
continue
}
var
famID
,
orientation
,
modelType
string
var
duration
int
if
m
:=
videoSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
// 视频模型: {family}-{orientation}-{duration}s
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
duration
,
_
=
strconv
.
Atoi
(
m
[
2
])
modelType
=
"video"
}
else
if
m
:=
imageSuffixRe
.
FindStringSubmatch
(
id
);
m
!=
nil
{
// 图像模型(带方向): {family}-{orientation}
famID
=
id
[
:
len
(
id
)
-
len
(
m
[
0
])]
orientation
=
m
[
1
]
modelType
=
"image"
}
else
if
cfg
,
ok
:=
soraModelConfigs
[
id
];
ok
&&
cfg
.
Type
==
"image"
{
// 已知的无后缀图像模型(如 gpt-image)
famID
=
id
orientation
=
"square"
modelType
=
"image"
}
else
if
strings
.
Contains
(
id
,
"image"
)
{
// 未知但名称包含 image 的模型,推断为图像模型
famID
=
id
orientation
=
"square"
modelType
=
"image"
}
else
{
continue
}
if
famID
==
""
{
continue
}
fd
,
ok
:=
families
[
famID
]
if
!
ok
{
fd
=
&
familyData
{
modelType
:
modelType
,
orientations
:
make
(
map
[
string
]
bool
),
durations
:
make
(
map
[
int
]
bool
),
}
families
[
famID
]
=
fd
}
if
orientation
!=
""
{
fd
.
orientations
[
orientation
]
=
true
}
if
duration
>
0
{
fd
.
durations
[
duration
]
=
true
}
}
famIDs
:=
make
([]
string
,
0
,
len
(
families
))
for
id
:=
range
families
{
famIDs
=
append
(
famIDs
,
id
)
}
sort
.
Slice
(
famIDs
,
func
(
i
,
j
int
)
bool
{
fi
,
fj
:=
families
[
famIDs
[
i
]],
families
[
famIDs
[
j
]]
if
fi
.
modelType
!=
fj
.
modelType
{
return
fi
.
modelType
==
"video"
}
return
famIDs
[
i
]
<
famIDs
[
j
]
})
result
:=
make
([]
SoraModelFamily
,
0
,
len
(
famIDs
))
for
_
,
famID
:=
range
famIDs
{
fd
:=
families
[
famID
]
fam
:=
SoraModelFamily
{
ID
:
famID
,
Name
:
soraFamilyNames
[
famID
],
Type
:
fd
.
modelType
,
}
if
fam
.
Name
==
""
{
fam
.
Name
=
famID
}
for
o
:=
range
fd
.
orientations
{
fam
.
Orientations
=
append
(
fam
.
Orientations
,
o
)
}
sort
.
Strings
(
fam
.
Orientations
)
for
d
:=
range
fd
.
durations
{
fam
.
Durations
=
append
(
fam
.
Durations
,
d
)
}
sort
.
Ints
(
fam
.
Durations
)
result
=
append
(
result
,
fam
)
}
return
result
}
// DefaultSoraModels returns the default Sora model list.
func
DefaultSoraModels
(
cfg
*
config
.
Config
)
[]
openai
.
Model
{
models
:=
make
([]
openai
.
Model
,
0
,
len
(
soraModelIDs
))
for
_
,
id
:=
range
soraModelIDs
{
models
=
append
(
models
,
openai
.
Model
{
ID
:
id
,
Object
:
"model"
,
OwnedBy
:
"openai"
,
Type
:
"model"
,
DisplayName
:
id
,
})
}
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SoraModelFilters
.
HidePromptEnhance
{
filtered
:=
models
[
:
0
]
for
_
,
model
:=
range
models
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
model
.
ID
),
"prompt-enhance"
)
{
continue
}
filtered
=
append
(
filtered
,
model
)
}
models
=
filtered
}
return
models
}
backend/internal/service/sora_quota_service.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService 管理 Sora 用户存储配额。
// 配额优先级:用户级 → 分组级 → 系统默认值。
type
SoraQuotaService
struct
{
userRepo
UserRepository
groupRepo
GroupRepository
settingService
*
SettingService
}
// NewSoraQuotaService 创建配额服务实例。
func
NewSoraQuotaService
(
userRepo
UserRepository
,
groupRepo
GroupRepository
,
settingService
*
SettingService
,
)
*
SoraQuotaService
{
return
&
SoraQuotaService
{
userRepo
:
userRepo
,
groupRepo
:
groupRepo
,
settingService
:
settingService
,
}
}
// QuotaInfo 返回给客户端的配额信息。
type
QuotaInfo
struct
{
QuotaBytes
int64
`json:"quota_bytes"`
// 总配额(0 表示无限制)
UsedBytes
int64
`json:"used_bytes"`
// 已使用
AvailableBytes
int64
`json:"available_bytes"`
// 剩余可用(无限制时为 0)
QuotaSource
string
`json:"quota_source"`
// 配额来源:user / group / system / unlimited
Source
string
`json:"source,omitempty"`
// 兼容旧字段
}
// ErrSoraStorageQuotaExceeded 表示配额不足。
var
ErrSoraStorageQuotaExceeded
=
errors
.
New
(
"sora storage quota exceeded"
)
// QuotaExceededError 包含配额不足的上下文信息。
type
QuotaExceededError
struct
{
QuotaBytes
int64
UsedBytes
int64
}
func
(
e
*
QuotaExceededError
)
Error
()
string
{
if
e
==
nil
{
return
"存储配额不足"
}
return
fmt
.
Sprintf
(
"存储配额不足(已用 %d / 配额 %d 字节)"
,
e
.
UsedBytes
,
e
.
QuotaBytes
)
}
type
soraQuotaAtomicUserRepository
interface
{
AddSoraStorageUsageWithQuota
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
,
effectiveQuota
int64
)
(
int64
,
error
)
ReleaseSoraStorageUsageAtomic
(
ctx
context
.
Context
,
userID
int64
,
deltaBytes
int64
)
(
int64
,
error
)
}
// GetQuota 获取用户的存储配额信息。
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
func
(
s
*
SoraQuotaService
)
GetQuota
(
ctx
context
.
Context
,
userID
int64
)
(
*
QuotaInfo
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
info
:=
&
QuotaInfo
{
UsedBytes
:
user
.
SoraStorageUsedBytes
,
}
// 1. 用户级配额
if
user
.
SoraStorageQuotaBytes
>
0
{
info
.
QuotaBytes
=
user
.
SoraStorageQuotaBytes
info
.
QuotaSource
=
"user"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
// 2. 分组级配额(取用户可用分组中最大的配额)
if
len
(
user
.
AllowedGroups
)
>
0
{
var
maxGroupQuota
int64
for
_
,
gid
:=
range
user
.
AllowedGroups
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
gid
)
if
err
!=
nil
{
continue
}
if
group
.
SoraStorageQuotaBytes
>
maxGroupQuota
{
maxGroupQuota
=
group
.
SoraStorageQuotaBytes
}
}
if
maxGroupQuota
>
0
{
info
.
QuotaBytes
=
maxGroupQuota
info
.
QuotaSource
=
"group"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
}
// 3. 系统默认值
defaultQuota
:=
s
.
getSystemDefaultQuota
(
ctx
)
if
defaultQuota
>
0
{
info
.
QuotaBytes
=
defaultQuota
info
.
QuotaSource
=
"system"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
calcAvailableBytes
(
info
.
QuotaBytes
,
info
.
UsedBytes
)
return
info
,
nil
}
// 无配额限制
info
.
QuotaSource
=
"unlimited"
info
.
Source
=
info
.
QuotaSource
info
.
AvailableBytes
=
0
return
info
,
nil
}
// CheckQuota 检查用户是否有足够的存储配额。
// 返回 nil 表示配额充足或无限制。
func
(
s
*
SoraQuotaService
)
CheckQuota
(
ctx
context
.
Context
,
userID
int64
,
additionalBytes
int64
)
error
{
quota
,
err
:=
s
.
GetQuota
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
// 0 表示无限制
if
quota
.
QuotaBytes
==
0
{
return
nil
}
if
quota
.
UsedBytes
+
additionalBytes
>
quota
.
QuotaBytes
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
return
nil
}
// AddUsage 原子累加用量(上传成功后调用)。
func
(
s
*
SoraQuotaService
)
AddUsage
(
ctx
context
.
Context
,
userID
int64
,
bytes
int64
)
error
{
if
bytes
<=
0
{
return
nil
}
quota
,
err
:=
s
.
GetQuota
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
if
quota
.
QuotaBytes
>
0
&&
quota
.
UsedBytes
+
bytes
>
quota
.
QuotaBytes
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
if
repo
,
ok
:=
s
.
userRepo
.
(
soraQuotaAtomicUserRepository
);
ok
{
newUsed
,
err
:=
repo
.
AddSoraStorageUsageWithQuota
(
ctx
,
userID
,
bytes
,
quota
.
QuotaBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSoraStorageQuotaExceeded
)
{
return
&
QuotaExceededError
{
QuotaBytes
:
quota
.
QuotaBytes
,
UsedBytes
:
quota
.
UsedBytes
,
}
}
return
fmt
.
Errorf
(
"update user quota usage (atomic): %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 累加用量 user=%d +%d total=%d"
,
userID
,
bytes
,
newUsed
)
return
nil
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user for quota update: %w"
,
err
)
}
user
.
SoraStorageUsedBytes
+=
bytes
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota usage: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 累加用量 user=%d +%d total=%d"
,
userID
,
bytes
,
user
.
SoraStorageUsedBytes
)
return
nil
}
// ReleaseUsage 释放用量(删除文件后调用)。
func
(
s
*
SoraQuotaService
)
ReleaseUsage
(
ctx
context
.
Context
,
userID
int64
,
bytes
int64
)
error
{
if
bytes
<=
0
{
return
nil
}
if
repo
,
ok
:=
s
.
userRepo
.
(
soraQuotaAtomicUserRepository
);
ok
{
newUsed
,
err
:=
repo
.
ReleaseSoraStorageUsageAtomic
(
ctx
,
userID
,
bytes
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota release (atomic): %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 释放用量 user=%d -%d total=%d"
,
userID
,
bytes
,
newUsed
)
return
nil
}
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user for quota release: %w"
,
err
)
}
user
.
SoraStorageUsedBytes
-=
bytes
if
user
.
SoraStorageUsedBytes
<
0
{
user
.
SoraStorageUsedBytes
=
0
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user quota release: %w"
,
err
)
}
logger
.
LegacyPrintf
(
"service.sora_quota"
,
"[SoraQuota] 释放用量 user=%d -%d total=%d"
,
userID
,
bytes
,
user
.
SoraStorageUsedBytes
)
return
nil
}
func
calcAvailableBytes
(
quotaBytes
,
usedBytes
int64
)
int64
{
if
quotaBytes
<=
0
{
return
0
}
if
usedBytes
>=
quotaBytes
{
return
0
}
return
quotaBytes
-
usedBytes
}
func
(
s
*
SoraQuotaService
)
getSystemDefaultQuota
(
ctx
context
.
Context
)
int64
{
if
s
.
settingService
==
nil
{
return
0
}
settings
,
err
:=
s
.
settingService
.
GetSoraS3Settings
(
ctx
)
if
err
!=
nil
{
return
0
}
return
settings
.
DefaultStorageQuotaBytes
}
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
func
(
s
*
SoraQuotaService
)
GetQuotaFromSettings
(
ctx
context
.
Context
)
int64
{
return
s
.
getSystemDefaultQuota
(
ctx
)
}
// SetUserQuota 设置用户级配额(管理员操作)。
func
SetUserSoraQuota
(
ctx
context
.
Context
,
userRepo
UserRepository
,
userID
int64
,
quotaBytes
int64
)
error
{
user
,
err
:=
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
user
.
SoraStorageQuotaBytes
=
quotaBytes
return
userRepo
.
Update
(
ctx
,
user
)
}
// ParseQuotaBytes 解析配额字符串为字节数。
func
ParseQuotaBytes
(
s
string
)
int64
{
v
,
_
:=
strconv
.
ParseInt
(
s
,
10
,
64
)
return
v
}
backend/internal/service/sora_quota_service_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
var
_
GroupRepository
=
(
*
stubGroupRepoForQuota
)(
nil
)
type
stubGroupRepoForQuota
struct
{
groups
map
[
int64
]
*
Group
}
func
newStubGroupRepoForQuota
()
*
stubGroupRepoForQuota
{
return
&
stubGroupRepoForQuota
{
groups
:
make
(
map
[
int64
]
*
Group
)}
}
func
(
r
*
stubGroupRepoForQuota
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
r
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
fmt
.
Errorf
(
"group not found"
)
}
func
(
r
*
stubGroupRepoForQuota
)
Create
(
context
.
Context
,
*
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
r
.
GetByID
(
context
.
Background
(),
id
)
}
func
(
r
*
stubGroupRepoForQuota
)
Update
(
context
.
Context
,
*
Group
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
DeleteCascade
(
context
.
Context
,
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListActive
(
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ListActiveByPlatform
(
context
.
Context
,
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
int64
,
error
)
{
return
0
,
0
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
GetAccountIDsByGroupIDs
(
context
.
Context
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
r
*
stubGroupRepoForQuota
)
BindAccountsToGroup
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
r
*
stubGroupRepoForQuota
)
UpdateSortOrders
(
context
.
Context
,
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
// ==================== Stub: SettingRepository (用于 SettingService) ====================
var
_
SettingRepository
=
(
*
stubSettingRepoForQuota
)(
nil
)
type
stubSettingRepoForQuota
struct
{
values
map
[
string
]
string
}
func
newStubSettingRepoForQuota
(
values
map
[
string
]
string
)
*
stubSettingRepoForQuota
{
if
values
==
nil
{
values
=
make
(
map
[
string
]
string
)
}
return
&
stubSettingRepoForQuota
{
values
:
values
}
}
func
(
r
*
stubSettingRepoForQuota
)
Get
(
_
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
&
Setting
{
Key
:
key
,
Value
:
v
},
nil
}
return
nil
,
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForQuota
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
v
,
ok
:=
r
.
values
[
key
];
ok
{
return
v
,
nil
}
return
""
,
ErrSettingNotFound
}
func
(
r
*
stubSettingRepoForQuota
)
Set
(
_
context
.
Context
,
key
,
value
string
)
error
{
r
.
values
[
key
]
=
value
return
nil
}
func
(
r
*
stubSettingRepoForQuota
)
GetMultiple
(
_
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
result
:=
make
(
map
[
string
]
string
)
for
_
,
k
:=
range
keys
{
if
v
,
ok
:=
r
.
values
[
k
];
ok
{
result
[
k
]
=
v
}
}
return
result
,
nil
}
func
(
r
*
stubSettingRepoForQuota
)
SetMultiple
(
_
context
.
Context
,
settings
map
[
string
]
string
)
error
{
for
k
,
v
:=
range
settings
{
r
.
values
[
k
]
=
v
}
return
nil
}
func
(
r
*
stubSettingRepoForQuota
)
GetAll
(
_
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
return
r
.
values
,
nil
}
func
(
r
*
stubSettingRepoForQuota
)
Delete
(
_
context
.
Context
,
key
string
)
error
{
delete
(
r
.
values
,
key
)
return
nil
}
// ==================== GetQuota ====================
func
TestGetQuota_UserLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
// 10MB
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
// 3MB
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
*
1024
*
1024
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
int64
(
3
*
1024
*
1024
),
quota
.
UsedBytes
)
require
.
Equal
(
t
,
"user"
,
quota
.
Source
)
}
func
TestGetQuota_GroupLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
// 用户级无配额
SoraStorageUsedBytes
:
1024
,
AllowedGroups
:
[]
int64
{
10
,
20
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
5
*
1024
*
1024
}
groupRepo
.
groups
[
20
]
=
&
Group
{
ID
:
20
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
20
*
1024
*
1024
),
quota
.
QuotaBytes
)
// 取最大值
require
.
Equal
(
t
,
"group"
,
quota
.
Source
)
}
func
TestGetQuota_SystemLevel
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
SoraStorageUsedBytes
:
512
}
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"104857600"
,
// 100MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
104857600
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
"system"
,
quota
.
Source
)
}
func
TestGetQuota_NoLimit
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
SoraStorageUsedBytes
:
0
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
quota
.
QuotaBytes
)
require
.
Equal
(
t
,
"unlimited"
,
quota
.
Source
)
}
func
TestGetQuota_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
_
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
999
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"get user"
)
}
func
TestGetQuota_GroupRepoError
(
t
*
testing
.
T
)
{
// 分组获取失败时跳过该分组(不影响整体)
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
999
},
// 不存在的分组
}
groupRepo
:=
newStubGroupRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"unlimited"
,
quota
.
Source
)
// 分组获取失败,回退到无限制
}
// ==================== CheckQuota ====================
func
TestCheckQuota_Sufficient
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
3
*
1024
*
1024
,
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1024
)
require
.
NoError
(
t
,
err
)
}
func
TestCheckQuota_Exceeded
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
10
*
1024
*
1024
,
SoraStorageUsedBytes
:
10
*
1024
*
1024
,
// 已满
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"配额不足"
)
}
func
TestCheckQuota_NoLimit
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
// 无限制
SoraStorageUsedBytes
:
1000000000
,
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
CheckQuota
(
context
.
Background
(),
1
,
999999999
)
require
.
NoError
(
t
,
err
)
// 无限制时始终通过
}
func
TestCheckQuota_ExactBoundary
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
1024
,
SoraStorageUsedBytes
:
1024
,
// 恰好满
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
// 额外 0 字节不超
require
.
NoError
(
t
,
svc
.
CheckQuota
(
context
.
Background
(),
1
,
0
))
// 额外 1 字节超出
require
.
Error
(
t
,
svc
.
CheckQuota
(
context
.
Background
(),
1
,
1
))
}
// ==================== AddUsage ====================
func
TestAddUsage_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
2048
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3072
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestAddUsage_ZeroBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
0
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestAddUsage_NegativeBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
-
100
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestAddUsage_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
999
,
1024
)
require
.
Error
(
t
,
err
)
}
func
TestAddUsage_UpdateError
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
0
}
userRepo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
AddUsage
(
context
.
Background
(),
1
,
1024
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"update user quota usage"
)
}
// ==================== ReleaseUsage ====================
func
TestReleaseUsage_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
3072
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2048
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestReleaseUsage_ClampToZero
(
t
*
testing
.
T
)
{
// 释放量大于已用量时,应 clamp 到 0
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
500
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
1000
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
}
func
TestReleaseUsage_ZeroBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
0
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestReleaseUsage_NegativeBytes
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
-
50
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1024
),
userRepo
.
users
[
1
]
.
SoraStorageUsedBytes
)
// 不变
}
func
TestReleaseUsage_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
999
,
1024
)
require
.
Error
(
t
,
err
)
}
func
TestReleaseUsage_UpdateError
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageUsedBytes
:
1024
}
userRepo
.
updateErr
=
fmt
.
Errorf
(
"db error"
)
svc
:=
NewSoraQuotaService
(
userRepo
,
nil
,
nil
)
err
:=
svc
.
ReleaseUsage
(
context
.
Background
(),
1
,
512
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"update user quota release"
)
}
// ==================== GetQuotaFromSettings ====================
func
TestGetQuotaFromSettings_NilSettingService
(
t
*
testing
.
T
)
{
svc
:=
NewSoraQuotaService
(
nil
,
nil
,
nil
)
require
.
Equal
(
t
,
int64
(
0
),
svc
.
GetQuotaFromSettings
(
context
.
Background
()))
}
func
TestGetQuotaFromSettings_WithSettings
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"52428800"
,
// 50MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
nil
,
nil
,
settingService
)
require
.
Equal
(
t
,
int64
(
52428800
),
svc
.
GetQuotaFromSettings
(
context
.
Background
()))
}
// ==================== SetUserSoraQuota ====================
func
TestSetUserSoraQuota_Success
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
}
err
:=
SetUserSoraQuota
(
context
.
Background
(),
userRepo
,
1
,
10
*
1024
*
1024
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
*
1024
*
1024
),
userRepo
.
users
[
1
]
.
SoraStorageQuotaBytes
)
}
func
TestSetUserSoraQuota_UserNotFound
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
err
:=
SetUserSoraQuota
(
context
.
Background
(),
userRepo
,
999
,
1024
)
require
.
Error
(
t
,
err
)
}
// ==================== ParseQuotaBytes ====================
func
TestParseQuotaBytes
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
int64
(
1048576
),
ParseQuotaBytes
(
"1048576"
))
require
.
Equal
(
t
,
int64
(
0
),
ParseQuotaBytes
(
""
))
require
.
Equal
(
t
,
int64
(
0
),
ParseQuotaBytes
(
"abc"
))
require
.
Equal
(
t
,
int64
(
-
1
),
ParseQuotaBytes
(
"-1"
))
}
// ==================== 优先级完整测试 ====================
func
TestQuotaPriority_UserOverridesGroup
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
5
*
1024
*
1024
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
nil
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"user"
,
quota
.
Source
)
// 用户级优先
require
.
Equal
(
t
,
int64
(
5
*
1024
*
1024
),
quota
.
QuotaBytes
)
}
func
TestQuotaPriority_GroupOverridesSystem
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
20
*
1024
*
1024
}
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"104857600"
,
// 100MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"group"
,
quota
.
Source
)
// 分组级优先于系统
require
.
Equal
(
t
,
int64
(
20
*
1024
*
1024
),
quota
.
QuotaBytes
)
}
func
TestQuotaPriority_FallbackToSystem
(
t
*
testing
.
T
)
{
userRepo
:=
newStubUserRepoForQuota
()
userRepo
.
users
[
1
]
=
&
User
{
ID
:
1
,
SoraStorageQuotaBytes
:
0
,
AllowedGroups
:
[]
int64
{
10
},
}
groupRepo
:=
newStubGroupRepoForQuota
()
groupRepo
.
groups
[
10
]
=
&
Group
{
ID
:
10
,
SoraStorageQuotaBytes
:
0
}
// 分组无配额
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraDefaultStorageQuotaBytes
:
"52428800"
,
// 50MB
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
svc
:=
NewSoraQuotaService
(
userRepo
,
groupRepo
,
settingService
)
quota
,
err
:=
svc
.
GetQuota
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"system"
,
quota
.
Source
)
require
.
Equal
(
t
,
int64
(
52428800
),
quota
.
QuotaBytes
)
}
backend/internal/service/sora_s3_storage.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"context"
"fmt"
"io"
"net/http"
"path"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
type
SoraS3Storage
struct
{
settingService
*
SettingService
mu
sync
.
RWMutex
client
*
s3
.
Client
cfg
*
SoraS3Settings
// 上次加载的配置快照
healthCheckedAt
time
.
Time
healthErr
error
healthTTL
time
.
Duration
}
const
defaultSoraS3HealthTTL
=
30
*
time
.
Second
// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
type
UpstreamDownloadError
struct
{
StatusCode
int
}
func
(
e
*
UpstreamDownloadError
)
Error
()
string
{
if
e
==
nil
{
return
"upstream download failed"
}
return
fmt
.
Sprintf
(
"upstream returned %d"
,
e
.
StatusCode
)
}
// NewSoraS3Storage 创建 S3 存储服务实例。
func
NewSoraS3Storage
(
settingService
*
SettingService
)
*
SoraS3Storage
{
return
&
SoraS3Storage
{
settingService
:
settingService
,
healthTTL
:
defaultSoraS3HealthTTL
,
}
}
// Enabled 返回 S3 存储是否已启用且配置有效。
func
(
s
*
SoraS3Storage
)
Enabled
(
ctx
context
.
Context
)
bool
{
cfg
,
err
:=
s
.
getConfig
(
ctx
)
if
err
!=
nil
||
cfg
==
nil
{
return
false
}
return
cfg
.
Enabled
&&
cfg
.
Bucket
!=
""
}
// getConfig 获取当前 S3 配置(从 settings 表读取)。
func
(
s
*
SoraS3Storage
)
getConfig
(
ctx
context
.
Context
)
(
*
SoraS3Settings
,
error
)
{
if
s
.
settingService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"setting service not available"
)
}
return
s
.
settingService
.
GetSoraS3Settings
(
ctx
)
}
// getClient 获取或初始化 S3 客户端(带缓存)。
// 配置变更时调用 RefreshClient 清除缓存。
func
(
s
*
SoraS3Storage
)
getClient
(
ctx
context
.
Context
)
(
*
s3
.
Client
,
*
SoraS3Settings
,
error
)
{
s
.
mu
.
RLock
()
if
s
.
client
!=
nil
&&
s
.
cfg
!=
nil
{
client
,
cfg
:=
s
.
client
,
s
.
cfg
s
.
mu
.
RUnlock
()
return
client
,
cfg
,
nil
}
s
.
mu
.
RUnlock
()
return
s
.
initClient
(
ctx
)
}
func
(
s
*
SoraS3Storage
)
initClient
(
ctx
context
.
Context
)
(
*
s3
.
Client
,
*
SoraS3Settings
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
// 双重检查
if
s
.
client
!=
nil
&&
s
.
cfg
!=
nil
{
return
s
.
client
,
s
.
cfg
,
nil
}
cfg
,
err
:=
s
.
getConfig
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"load s3 config: %w"
,
err
)
}
if
!
cfg
.
Enabled
{
return
nil
,
nil
,
fmt
.
Errorf
(
"sora s3 storage is disabled"
)
}
if
cfg
.
Bucket
==
""
||
cfg
.
AccessKeyID
==
""
||
cfg
.
SecretAccessKey
==
""
{
return
nil
,
nil
,
fmt
.
Errorf
(
"sora s3 config incomplete: bucket, access_key_id, secret_access_key are required"
)
}
client
,
region
,
err
:=
buildSoraS3Client
(
ctx
,
cfg
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
s
.
client
=
client
s
.
cfg
=
cfg
logger
.
LegacyPrintf
(
"service.sora_s3"
,
"[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s"
,
cfg
.
Bucket
,
cfg
.
Endpoint
,
region
)
return
client
,
cfg
,
nil
}
// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
// 应在系统设置中 S3 配置变更时调用。
func
(
s
*
SoraS3Storage
)
RefreshClient
()
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
client
=
nil
s
.
cfg
=
nil
s
.
healthCheckedAt
=
time
.
Time
{}
s
.
healthErr
=
nil
logger
.
LegacyPrintf
(
"service.sora_s3"
,
"[SoraS3] 客户端缓存已清除,下次使用将重新初始化"
)
}
// TestConnection 测试 S3 连接(HeadBucket)。
func
(
s
*
SoraS3Storage
)
TestConnection
(
ctx
context
.
Context
)
error
{
client
,
cfg
,
err
:=
s
.
getClient
(
ctx
)
if
err
!=
nil
{
return
err
}
_
,
err
=
client
.
HeadBucket
(
ctx
,
&
s3
.
HeadBucketInput
{
Bucket
:
&
cfg
.
Bucket
,
})
if
err
!=
nil
{
return
fmt
.
Errorf
(
"s3 HeadBucket failed: %w"
,
err
)
}
return
nil
}
// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
func
(
s
*
SoraS3Storage
)
IsHealthy
(
ctx
context
.
Context
)
bool
{
if
s
==
nil
{
return
false
}
now
:=
time
.
Now
()
s
.
mu
.
RLock
()
lastCheck
:=
s
.
healthCheckedAt
lastErr
:=
s
.
healthErr
ttl
:=
s
.
healthTTL
s
.
mu
.
RUnlock
()
if
ttl
<=
0
{
ttl
=
defaultSoraS3HealthTTL
}
if
!
lastCheck
.
IsZero
()
&&
now
.
Sub
(
lastCheck
)
<
ttl
{
return
lastErr
==
nil
}
err
:=
s
.
TestConnection
(
ctx
)
s
.
mu
.
Lock
()
s
.
healthCheckedAt
=
time
.
Now
()
s
.
healthErr
=
err
s
.
mu
.
Unlock
()
return
err
==
nil
}
// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
func
(
s
*
SoraS3Storage
)
TestConnectionWithSettings
(
ctx
context
.
Context
,
cfg
*
SoraS3Settings
)
error
{
if
cfg
==
nil
{
return
fmt
.
Errorf
(
"s3 config is required"
)
}
if
!
cfg
.
Enabled
{
return
fmt
.
Errorf
(
"sora s3 storage is disabled"
)
}
if
cfg
.
Endpoint
==
""
||
cfg
.
Bucket
==
""
||
cfg
.
AccessKeyID
==
""
||
cfg
.
SecretAccessKey
==
""
{
return
fmt
.
Errorf
(
"sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required"
)
}
client
,
_
,
err
:=
buildSoraS3Client
(
ctx
,
cfg
)
if
err
!=
nil
{
return
err
}
_
,
err
=
client
.
HeadBucket
(
ctx
,
&
s3
.
HeadBucketInput
{
Bucket
:
&
cfg
.
Bucket
,
})
if
err
!=
nil
{
return
fmt
.
Errorf
(
"s3 HeadBucket failed: %w"
,
err
)
}
return
nil
}
// GenerateObjectKey 生成 S3 object key。
// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
func
(
s
*
SoraS3Storage
)
GenerateObjectKey
(
prefix
string
,
userID
int64
,
ext
string
)
string
{
if
!
strings
.
HasPrefix
(
ext
,
"."
)
{
ext
=
"."
+
ext
}
datePath
:=
time
.
Now
()
.
Format
(
"2006/01/02"
)
key
:=
fmt
.
Sprintf
(
"sora/%d/%s/%s%s"
,
userID
,
datePath
,
uuid
.
NewString
(),
ext
)
if
prefix
!=
""
{
prefix
=
strings
.
TrimRight
(
prefix
,
"/"
)
+
"/"
key
=
prefix
+
key
}
return
key
}
// UploadFromURL 从上游 URL 下载并流式上传到 S3。
// 返回 S3 object key。
func
(
s
*
SoraS3Storage
)
UploadFromURL
(
ctx
context
.
Context
,
userID
int64
,
sourceURL
string
)
(
string
,
int64
,
error
)
{
client
,
cfg
,
err
:=
s
.
getClient
(
ctx
)
if
err
!=
nil
{
return
""
,
0
,
err
}
// 下载源文件
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
sourceURL
,
nil
)
if
err
!=
nil
{
return
""
,
0
,
fmt
.
Errorf
(
"create download request: %w"
,
err
)
}
httpClient
:=
&
http
.
Client
{
Timeout
:
5
*
time
.
Minute
}
resp
,
err
:=
httpClient
.
Do
(
req
)
if
err
!=
nil
{
return
""
,
0
,
fmt
.
Errorf
(
"download from upstream: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
""
,
0
,
&
UpstreamDownloadError
{
StatusCode
:
resp
.
StatusCode
}
}
// 推断文件扩展名
ext
:=
fileExtFromURL
(
sourceURL
)
if
ext
==
""
{
ext
=
fileExtFromContentType
(
resp
.
Header
.
Get
(
"Content-Type"
))
}
if
ext
==
""
{
ext
=
".bin"
}
objectKey
:=
s
.
GenerateObjectKey
(
cfg
.
Prefix
,
userID
,
ext
)
// 检测 Content-Type
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/octet-stream"
}
reader
,
writer
:=
io
.
Pipe
()
uploadErrCh
:=
make
(
chan
error
,
1
)
go
func
()
{
defer
close
(
uploadErrCh
)
input
:=
&
s3
.
PutObjectInput
{
Bucket
:
&
cfg
.
Bucket
,
Key
:
&
objectKey
,
Body
:
reader
,
ContentType
:
&
contentType
,
}
if
resp
.
ContentLength
>=
0
{
input
.
ContentLength
=
&
resp
.
ContentLength
}
_
,
uploadErr
:=
client
.
PutObject
(
ctx
,
input
)
uploadErrCh
<-
uploadErr
}()
written
,
copyErr
:=
io
.
CopyBuffer
(
writer
,
resp
.
Body
,
make
([]
byte
,
1024
*
1024
))
_
=
writer
.
CloseWithError
(
copyErr
)
uploadErr
:=
<-
uploadErrCh
if
copyErr
!=
nil
{
return
""
,
0
,
fmt
.
Errorf
(
"stream upload copy failed: %w"
,
copyErr
)
}
if
uploadErr
!=
nil
{
return
""
,
0
,
fmt
.
Errorf
(
"s3 upload: %w"
,
uploadErr
)
}
logger
.
LegacyPrintf
(
"service.sora_s3"
,
"[SoraS3] 上传完成 key=%s size=%d"
,
objectKey
,
written
)
return
objectKey
,
written
,
nil
}
func
buildSoraS3Client
(
ctx
context
.
Context
,
cfg
*
SoraS3Settings
)
(
*
s3
.
Client
,
string
,
error
)
{
if
cfg
==
nil
{
return
nil
,
""
,
fmt
.
Errorf
(
"s3 config is required"
)
}
region
:=
cfg
.
Region
if
region
==
""
{
region
=
"us-east-1"
}
awsCfg
,
err
:=
awsconfig
.
LoadDefaultConfig
(
ctx
,
awsconfig
.
WithRegion
(
region
),
awsconfig
.
WithCredentialsProvider
(
credentials
.
NewStaticCredentialsProvider
(
cfg
.
AccessKeyID
,
cfg
.
SecretAccessKey
,
""
),
),
)
if
err
!=
nil
{
return
nil
,
""
,
fmt
.
Errorf
(
"load aws config: %w"
,
err
)
}
client
:=
s3
.
NewFromConfig
(
awsCfg
,
func
(
o
*
s3
.
Options
)
{
if
cfg
.
Endpoint
!=
""
{
o
.
BaseEndpoint
=
&
cfg
.
Endpoint
}
if
cfg
.
ForcePathStyle
{
o
.
UsePathStyle
=
true
}
o
.
APIOptions
=
append
(
o
.
APIOptions
,
v4
.
SwapComputePayloadSHA256ForUnsignedPayloadMiddleware
)
// 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
o
.
RequestChecksumCalculation
=
aws
.
RequestChecksumCalculationWhenRequired
})
return
client
,
region
,
nil
}
// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
func
(
s
*
SoraS3Storage
)
DeleteObjects
(
ctx
context
.
Context
,
objectKeys
[]
string
)
error
{
if
len
(
objectKeys
)
==
0
{
return
nil
}
client
,
cfg
,
err
:=
s
.
getClient
(
ctx
)
if
err
!=
nil
{
return
err
}
var
lastErr
error
for
_
,
key
:=
range
objectKeys
{
k
:=
key
_
,
err
:=
client
.
DeleteObject
(
ctx
,
&
s3
.
DeleteObjectInput
{
Bucket
:
&
cfg
.
Bucket
,
Key
:
&
k
,
})
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.sora_s3"
,
"[SoraS3] 删除失败 key=%s err=%v"
,
key
,
err
)
lastErr
=
err
}
}
return
lastErr
}
// GetAccessURL 获取 S3 文件的访问 URL。
// CDN URL 优先,否则生成 24h 预签名 URL。
func
(
s
*
SoraS3Storage
)
GetAccessURL
(
ctx
context
.
Context
,
objectKey
string
)
(
string
,
error
)
{
_
,
cfg
,
err
:=
s
.
getClient
(
ctx
)
if
err
!=
nil
{
return
""
,
err
}
// CDN URL 优先
if
cfg
.
CDNURL
!=
""
{
cdnBase
:=
strings
.
TrimRight
(
cfg
.
CDNURL
,
"/"
)
return
cdnBase
+
"/"
+
objectKey
,
nil
}
// 生成 24h 预签名 URL
return
s
.
GeneratePresignedURL
(
ctx
,
objectKey
,
24
*
time
.
Hour
)
}
// GeneratePresignedURL 生成预签名 URL。
func
(
s
*
SoraS3Storage
)
GeneratePresignedURL
(
ctx
context
.
Context
,
objectKey
string
,
ttl
time
.
Duration
)
(
string
,
error
)
{
client
,
cfg
,
err
:=
s
.
getClient
(
ctx
)
if
err
!=
nil
{
return
""
,
err
}
presignClient
:=
s3
.
NewPresignClient
(
client
)
result
,
err
:=
presignClient
.
PresignGetObject
(
ctx
,
&
s3
.
GetObjectInput
{
Bucket
:
&
cfg
.
Bucket
,
Key
:
&
objectKey
,
},
s3
.
WithPresignExpires
(
ttl
))
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"presign url: %w"
,
err
)
}
return
result
.
URL
,
nil
}
// GetMediaType 从 object key 推断媒体类型(image/video)。
func
GetMediaTypeFromKey
(
objectKey
string
)
string
{
ext
:=
strings
.
ToLower
(
path
.
Ext
(
objectKey
))
switch
ext
{
case
".mp4"
,
".mov"
,
".webm"
,
".m4v"
,
".avi"
,
".mkv"
,
".3gp"
,
".flv"
:
return
"video"
default
:
return
"image"
}
}
backend/internal/service/sora_s3_storage_test.go
deleted
100644 → 0
View file @
f585a15e
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ==================== RefreshClient ====================
func
TestRefreshClient
(
t
*
testing
.
T
)
{
s
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
require
.
NotNil
(
t
,
s
.
client
)
require
.
NotNil
(
t
,
s
.
cfg
)
s
.
RefreshClient
()
require
.
Nil
(
t
,
s
.
client
)
require
.
Nil
(
t
,
s
.
cfg
)
}
func
TestRefreshClient_AlreadyNil
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
s
.
RefreshClient
()
// 不应 panic
require
.
Nil
(
t
,
s
.
client
)
require
.
Nil
(
t
,
s
.
cfg
)
}
// ==================== GetMediaTypeFromKey ====================
func
TestGetMediaTypeFromKey_VideoExtensions
(
t
*
testing
.
T
)
{
for
_
,
ext
:=
range
[]
string
{
".mp4"
,
".mov"
,
".webm"
,
".m4v"
,
".avi"
,
".mkv"
,
".3gp"
,
".flv"
}
{
require
.
Equal
(
t
,
"video"
,
GetMediaTypeFromKey
(
"path/to/file"
+
ext
),
"ext=%s"
,
ext
)
}
}
func
TestGetMediaTypeFromKey_VideoUpperCase
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"video"
,
GetMediaTypeFromKey
(
"file.MP4"
))
require
.
Equal
(
t
,
"video"
,
GetMediaTypeFromKey
(
"file.MOV"
))
}
func
TestGetMediaTypeFromKey_ImageExtensions
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.png"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.jpg"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.jpeg"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.gif"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.webp"
))
}
func
TestGetMediaTypeFromKey_NoExtension
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"path/to/file"
))
}
func
TestGetMediaTypeFromKey_UnknownExtension
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.bin"
))
require
.
Equal
(
t
,
"image"
,
GetMediaTypeFromKey
(
"file.xyz"
))
}
// ==================== Enabled ====================
func
TestEnabled_NilSettingService
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
require
.
False
(
t
,
s
.
Enabled
(
context
.
Background
()))
}
func
TestEnabled_ConfigDisabled
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"false"
,
SettingKeySoraS3Bucket
:
"test-bucket"
,
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
require
.
False
(
t
,
s
.
Enabled
(
context
.
Background
()))
}
func
TestEnabled_ConfigEnabledWithBucket
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"true"
,
SettingKeySoraS3Bucket
:
"my-bucket"
,
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
require
.
True
(
t
,
s
.
Enabled
(
context
.
Background
()))
}
func
TestEnabled_ConfigEnabledEmptyBucket
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"true"
,
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
require
.
False
(
t
,
s
.
Enabled
(
context
.
Background
()))
}
// ==================== initClient ====================
func
TestInitClient_Disabled
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"false"
,
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
_
,
_
,
err
:=
s
.
getClient
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"disabled"
)
}
func
TestInitClient_IncompleteConfig
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"true"
,
SettingKeySoraS3Bucket
:
"test-bucket"
,
// 缺少 access_key_id 和 secret_access_key
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
_
,
_
,
err
:=
s
.
getClient
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"incomplete"
)
}
func
TestInitClient_DefaultRegion
(
t
*
testing
.
T
)
{
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"true"
,
SettingKeySoraS3Bucket
:
"test-bucket"
,
SettingKeySoraS3AccessKeyID
:
"AKID"
,
SettingKeySoraS3SecretAccessKey
:
"SECRET"
,
// Region 为空 → 默认 us-east-1
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
client
,
cfg
,
err
:=
s
.
getClient
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
client
)
require
.
Equal
(
t
,
"test-bucket"
,
cfg
.
Bucket
)
}
func
TestInitClient_DoubleCheck
(
t
*
testing
.
T
)
{
// 验证双重检查锁定:第二次 getClient 命中缓存
settingRepo
:=
newStubSettingRepoForQuota
(
map
[
string
]
string
{
SettingKeySoraS3Enabled
:
"true"
,
SettingKeySoraS3Bucket
:
"test-bucket"
,
SettingKeySoraS3AccessKeyID
:
"AKID"
,
SettingKeySoraS3SecretAccessKey
:
"SECRET"
,
})
settingService
:=
NewSettingService
(
settingRepo
,
&
config
.
Config
{})
s
:=
NewSoraS3Storage
(
settingService
)
client1
,
_
,
err1
:=
s
.
getClient
(
context
.
Background
())
require
.
NoError
(
t
,
err1
)
client2
,
_
,
err2
:=
s
.
getClient
(
context
.
Background
())
require
.
NoError
(
t
,
err2
)
require
.
Equal
(
t
,
client1
,
client2
)
// 同一客户端实例
}
func
TestInitClient_NilSettingService
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
_
,
_
,
err
:=
s
.
getClient
(
context
.
Background
())
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"setting service not available"
)
}
// ==================== GenerateObjectKey ====================
func
TestGenerateObjectKey_ExtWithoutDot
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
key
:=
s
.
GenerateObjectKey
(
""
,
1
,
"mp4"
)
require
.
Contains
(
t
,
key
,
".mp4"
)
require
.
True
(
t
,
len
(
key
)
>
0
)
}
func
TestGenerateObjectKey_ExtWithDot
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
key
:=
s
.
GenerateObjectKey
(
""
,
1
,
".mp4"
)
require
.
Contains
(
t
,
key
,
".mp4"
)
// 不应出现 ..mp4
require
.
NotContains
(
t
,
key
,
"..mp4"
)
}
func
TestGenerateObjectKey_WithPrefix
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
key
:=
s
.
GenerateObjectKey
(
"uploads/"
,
42
,
".png"
)
require
.
True
(
t
,
len
(
key
)
>
0
)
require
.
Contains
(
t
,
key
,
"uploads/sora/42/"
)
}
func
TestGenerateObjectKey_PrefixWithoutTrailingSlash
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
key
:=
s
.
GenerateObjectKey
(
"uploads"
,
42
,
".png"
)
require
.
Contains
(
t
,
key
,
"uploads/sora/42/"
)
}
// ==================== GeneratePresignedURL ====================
func
TestGeneratePresignedURL_GetClientError
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
// settingService=nil → getClient 失败
_
,
err
:=
s
.
GeneratePresignedURL
(
context
.
Background
(),
"key"
,
3600
)
require
.
Error
(
t
,
err
)
}
// ==================== GetAccessURL ====================
func
TestGetAccessURL_CDN
(
t
*
testing
.
T
)
{
s
:=
newS3StorageWithCDN
(
"https://cdn.example.com"
)
url
,
err
:=
s
.
GetAccessURL
(
context
.
Background
(),
"sora/1/2024/01/01/video.mp4"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/sora/1/2024/01/01/video.mp4"
,
url
)
}
func
TestGetAccessURL_CDNTrailingSlash
(
t
*
testing
.
T
)
{
s
:=
newS3StorageWithCDN
(
"https://cdn.example.com/"
)
url
,
err
:=
s
.
GetAccessURL
(
context
.
Background
(),
"key.mp4"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://cdn.example.com/key.mp4"
,
url
)
}
func
TestGetAccessURL_GetClientError
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
_
,
err
:=
s
.
GetAccessURL
(
context
.
Background
(),
"key"
)
require
.
Error
(
t
,
err
)
}
// ==================== TestConnection ====================
func
TestTestConnection_GetClientError
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
err
:=
s
.
TestConnection
(
context
.
Background
())
require
.
Error
(
t
,
err
)
}
// ==================== UploadFromURL ====================
func
TestUploadFromURL_GetClientError
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
_
,
_
,
err
:=
s
.
UploadFromURL
(
context
.
Background
(),
1
,
"https://example.com/file.mp4"
)
require
.
Error
(
t
,
err
)
}
// ==================== DeleteObjects ====================
func
TestDeleteObjects_EmptyKeys
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
err
:=
s
.
DeleteObjects
(
context
.
Background
(),
[]
string
{})
require
.
NoError
(
t
,
err
)
// 空列表直接返回
}
func
TestDeleteObjects_NilKeys
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
err
:=
s
.
DeleteObjects
(
context
.
Background
(),
nil
)
require
.
NoError
(
t
,
err
)
// nil 列表直接返回
}
func
TestDeleteObjects_GetClientError
(
t
*
testing
.
T
)
{
s
:=
NewSoraS3Storage
(
nil
)
err
:=
s
.
DeleteObjects
(
context
.
Background
(),
[]
string
{
"key1"
,
"key2"
})
require
.
Error
(
t
,
err
)
}
backend/internal/service/sora_sdk_client.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/DouDOU-start/go-sora2api/sora"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
openaioauth
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/tidwall/gjson"
)
// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。
// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。
type
SoraSDKClient
struct
{
cfg
*
config
.
Config
httpUpstream
HTTPUpstream
tokenProvider
*
OpenAITokenProvider
accountRepo
AccountRepository
soraAccountRepo
SoraAccountRepository
// 每个 proxyURL 对应一个 SDK 客户端实例
sdkClients
sync
.
Map
// key: proxyURL (string), value: *sora.Client
}
// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端
func
NewSoraSDKClient
(
cfg
*
config
.
Config
,
httpUpstream
HTTPUpstream
,
tokenProvider
*
OpenAITokenProvider
)
*
SoraSDKClient
{
return
&
SoraSDKClient
{
cfg
:
cfg
,
httpUpstream
:
httpUpstream
,
tokenProvider
:
tokenProvider
,
}
}
// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化)
func
(
c
*
SoraSDKClient
)
SetAccountRepositories
(
accountRepo
AccountRepository
,
soraAccountRepo
SoraAccountRepository
)
{
if
c
==
nil
{
return
}
c
.
accountRepo
=
accountRepo
c
.
soraAccountRepo
=
soraAccountRepo
}
// Enabled 判断是否启用 Sora
func
(
c
*
SoraSDKClient
)
Enabled
()
bool
{
if
c
==
nil
||
c
.
cfg
==
nil
{
return
false
}
return
strings
.
TrimSpace
(
c
.
cfg
.
Sora
.
Client
.
BaseURL
)
!=
""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。
func
(
c
*
SoraSDKClient
)
PreflightCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
,
modelCfg
SoraModelConfig
)
error
{
if
modelCfg
.
Type
!=
"video"
{
return
nil
}
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
err
}
balance
,
err
:=
sdkClient
.
GetCreditBalance
(
ctx
,
token
)
if
err
!=
nil
{
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
logger
.
LegacyPrintf
(
"service.sora_sdk"
,
"[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s"
,
accountID
,
requestedModel
,
logredact
.
RedactText
(
err
.
Error
()),
)
return
&
SoraUpstreamError
{
StatusCode
:
http
.
StatusForbidden
,
Message
:
"当前账号未开通 Sora2 能力或无可用配额"
,
}
}
if
balance
.
RateLimitReached
||
balance
.
RemainingCount
<=
0
{
msg
:=
"当前账号 Sora2 可用配额不足"
if
requestedModel
!=
""
{
msg
=
fmt
.
Sprintf
(
"当前账号 %s 可用配额不足"
,
requestedModel
)
}
return
&
SoraUpstreamError
{
StatusCode
:
http
.
StatusTooManyRequests
,
Message
:
msg
,
}
}
return
nil
}
func
(
c
*
SoraSDKClient
)
UploadImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
,
filename
string
)
(
string
,
error
)
{
if
len
(
data
)
==
0
{
return
""
,
errors
.
New
(
"empty image data"
)
}
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
if
filename
==
""
{
filename
=
"image.png"
}
mediaID
,
err
:=
sdkClient
.
UploadImage
(
ctx
,
token
,
data
,
filename
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
mediaID
,
nil
}
func
(
c
*
SoraSDKClient
)
CreateImageTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraImageRequest
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
sentinel
,
err
:=
sdkClient
.
GenerateSentinelToken
(
ctx
,
token
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
var
taskID
string
if
strings
.
TrimSpace
(
req
.
MediaID
)
!=
""
{
taskID
,
err
=
sdkClient
.
CreateImageTaskWithImage
(
ctx
,
token
,
sentinel
,
req
.
Prompt
,
req
.
Width
,
req
.
Height
,
req
.
MediaID
)
}
else
{
taskID
,
err
=
sdkClient
.
CreateImageTask
(
ctx
,
token
,
sentinel
,
req
.
Prompt
,
req
.
Width
,
req
.
Height
)
}
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
taskID
,
nil
}
func
(
c
*
SoraSDKClient
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraVideoRequest
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
sentinel
,
err
:=
sdkClient
.
GenerateSentinelToken
(
ctx
,
token
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
orientation
:=
req
.
Orientation
if
orientation
==
""
{
orientation
=
"landscape"
}
nFrames
:=
req
.
Frames
if
nFrames
<=
0
{
nFrames
=
450
}
model
:=
req
.
Model
if
model
==
""
{
model
=
"sy_8"
}
size
:=
req
.
Size
if
size
==
""
{
size
=
"small"
}
videoCount
:=
req
.
VideoCount
if
videoCount
<=
0
{
videoCount
=
1
}
if
videoCount
>
3
{
videoCount
=
3
}
// Remix 模式
if
strings
.
TrimSpace
(
req
.
RemixTargetID
)
!=
""
{
if
videoCount
>
1
{
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
c
.
debugLogf
(
"video_count_ignored_for_remix account_id=%d count=%d"
,
accountID
,
videoCount
)
}
styleID
:=
""
// SDK ExtractStyle 可从 prompt 中提取
taskID
,
err
:=
sdkClient
.
RemixVideo
(
ctx
,
token
,
sentinel
,
req
.
RemixTargetID
,
req
.
Prompt
,
orientation
,
nFrames
,
styleID
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
taskID
,
nil
}
// 普通视频(文生视频或图生视频)
var
taskID
string
if
videoCount
<=
1
{
taskID
,
err
=
sdkClient
.
CreateVideoTaskWithOptions
(
ctx
,
token
,
sentinel
,
req
.
Prompt
,
orientation
,
nFrames
,
model
,
size
,
req
.
MediaID
,
""
)
}
else
{
taskID
,
err
=
c
.
createVideoTaskWithVariants
(
ctx
,
account
,
token
,
sentinel
,
req
.
Prompt
,
orientation
,
nFrames
,
model
,
size
,
req
.
MediaID
,
videoCount
)
}
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
taskID
,
nil
}
func
(
c
*
SoraSDKClient
)
createVideoTaskWithVariants
(
ctx
context
.
Context
,
account
*
Account
,
accessToken
string
,
sentinelToken
string
,
prompt
string
,
orientation
string
,
nFrames
int
,
model
string
,
size
string
,
mediaID
string
,
videoCount
int
,
)
(
string
,
error
)
{
inpaintItems
:=
make
([]
any
,
0
,
1
)
if
strings
.
TrimSpace
(
mediaID
)
!=
""
{
inpaintItems
=
append
(
inpaintItems
,
map
[
string
]
any
{
"kind"
:
"upload"
,
"upload_id"
:
mediaID
,
})
}
payload
:=
map
[
string
]
any
{
"kind"
:
"video"
,
"prompt"
:
prompt
,
"orientation"
:
orientation
,
"size"
:
size
,
"n_frames"
:
nFrames
,
"n_variants"
:
videoCount
,
"model"
:
model
,
"inpaint_items"
:
inpaintItems
,
"style_id"
:
nil
,
}
raw
,
err
:=
c
.
doSoraBackendJSON
(
ctx
,
account
,
http
.
MethodPost
,
"/nf/create"
,
accessToken
,
sentinelToken
,
payload
)
if
err
!=
nil
{
return
""
,
err
}
taskID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
raw
,
"id"
)
.
String
())
if
taskID
==
""
{
return
""
,
errors
.
New
(
"create video task response missing id"
)
}
return
taskID
,
nil
}
func
(
c
*
SoraSDKClient
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraStoryboardRequest
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
sentinel
,
err
:=
sdkClient
.
GenerateSentinelToken
(
ctx
,
token
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
orientation
:=
req
.
Orientation
if
orientation
==
""
{
orientation
=
"landscape"
}
nFrames
:=
req
.
Frames
if
nFrames
<=
0
{
nFrames
=
450
}
taskID
,
err
:=
sdkClient
.
CreateStoryboardTask
(
ctx
,
token
,
sentinel
,
req
.
Prompt
,
orientation
,
nFrames
,
req
.
MediaID
,
""
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
taskID
,
nil
}
func
(
c
*
SoraSDKClient
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
if
len
(
data
)
==
0
{
return
""
,
errors
.
New
(
"empty video data"
)
}
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
cameoID
,
err
:=
sdkClient
.
UploadCharacterVideo
(
ctx
,
token
,
data
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
cameoID
,
nil
}
func
(
c
*
SoraSDKClient
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
(
*
SoraCameoStatus
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
nil
,
err
}
status
,
err
:=
sdkClient
.
GetCameoStatus
(
ctx
,
token
,
cameoID
)
if
err
!=
nil
{
return
nil
,
c
.
wrapSDKError
(
err
,
account
)
}
return
&
SoraCameoStatus
{
Status
:
status
.
Status
,
DisplayNameHint
:
status
.
DisplayNameHint
,
UsernameHint
:
status
.
UsernameHint
,
ProfileAssetURL
:
status
.
ProfileAssetURL
,
},
nil
}
func
(
c
*
SoraSDKClient
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
imageURL
string
)
([]
byte
,
error
)
{
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
nil
,
err
}
data
,
err
:=
sdkClient
.
DownloadCharacterImage
(
ctx
,
imageURL
)
if
err
!=
nil
{
return
nil
,
c
.
wrapSDKError
(
err
,
account
)
}
return
data
,
nil
}
func
(
c
*
SoraSDKClient
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
Account
,
data
[]
byte
)
(
string
,
error
)
{
if
len
(
data
)
==
0
{
return
""
,
errors
.
New
(
"empty character image"
)
}
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
assetPointer
,
err
:=
sdkClient
.
UploadCharacterImage
(
ctx
,
token
,
data
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
assetPointer
,
nil
}
func
(
c
*
SoraSDKClient
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
Account
,
req
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
characterID
,
err
:=
sdkClient
.
FinalizeCharacter
(
ctx
,
token
,
req
.
CameoID
,
req
.
Username
,
req
.
DisplayName
,
req
.
ProfileAssetPointer
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
characterID
,
nil
}
func
(
c
*
SoraSDKClient
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
Account
,
cameoID
string
)
error
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
err
}
if
err
:=
sdkClient
.
SetCharacterPublic
(
ctx
,
token
,
cameoID
);
err
!=
nil
{
return
c
.
wrapSDKError
(
err
,
account
)
}
return
nil
}
func
(
c
*
SoraSDKClient
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
Account
,
characterID
string
)
error
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
err
}
if
err
:=
sdkClient
.
DeleteCharacter
(
ctx
,
token
,
characterID
);
err
!=
nil
{
return
c
.
wrapSDKError
(
err
,
account
)
}
return
nil
}
func
(
c
*
SoraSDKClient
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
Account
,
generationID
string
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
sentinel
,
err
:=
sdkClient
.
GenerateSentinelToken
(
ctx
,
token
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
postID
,
err
:=
sdkClient
.
PublishVideo
(
ctx
,
token
,
sentinel
,
generationID
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
postID
,
nil
}
func
(
c
*
SoraSDKClient
)
DeletePost
(
ctx
context
.
Context
,
account
*
Account
,
postID
string
)
error
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
err
}
if
err
:=
sdkClient
.
DeletePost
(
ctx
,
token
,
postID
);
err
!=
nil
{
return
c
.
wrapSDKError
(
err
,
account
)
}
return
nil
}
// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。
// SDK 不涉及此功能,保留自建实现。
func
(
c
*
SoraSDKClient
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
parseURL
=
strings
.
TrimRight
(
strings
.
TrimSpace
(
parseURL
),
"/"
)
if
parseURL
==
""
{
return
""
,
errors
.
New
(
"custom parse url is required"
)
}
if
strings
.
TrimSpace
(
parseToken
)
==
""
{
return
""
,
errors
.
New
(
"custom parse token is required"
)
}
shareURL
:=
"https://sora.chatgpt.com/p/"
+
strings
.
TrimSpace
(
postID
)
payload
:=
map
[
string
]
any
{
"url"
:
shareURL
,
"token"
:
strings
.
TrimSpace
(
parseToken
),
}
body
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
""
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
parseURL
+
"/get-sora-link"
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
""
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
proxyURL
:=
c
.
resolveProxyURL
(
account
)
accountID
:=
int64
(
0
)
accountConcurrency
:=
0
if
account
!=
nil
{
accountID
=
account
.
ID
accountConcurrency
=
account
.
Concurrency
}
var
resp
*
http
.
Response
if
c
.
httpUpstream
!=
nil
{
resp
,
err
=
c
.
httpUpstream
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
else
{
resp
,
err
=
http
.
DefaultClient
.
Do
(
req
)
}
if
err
!=
nil
{
return
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
raw
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
4
<<
20
))
if
err
!=
nil
{
return
""
,
err
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
""
,
fmt
.
Errorf
(
"custom parse failed: %d %s"
,
resp
.
StatusCode
,
truncateForLog
(
raw
,
256
))
}
downloadLink
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
raw
,
"download_link"
)
.
String
())
if
downloadLink
==
""
{
return
""
,
errors
.
New
(
"custom parse response missing download_link"
)
}
return
downloadLink
,
nil
}
func
(
c
*
SoraSDKClient
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
if
strings
.
TrimSpace
(
expansionLevel
)
==
""
{
expansionLevel
=
"medium"
}
if
durationS
<=
0
{
durationS
=
10
}
enhanced
,
err
:=
sdkClient
.
EnhancePrompt
(
ctx
,
token
,
prompt
,
expansionLevel
,
durationS
)
if
err
!=
nil
{
return
""
,
c
.
wrapSDKError
(
err
,
account
)
}
return
enhanced
,
nil
}
func
(
c
*
SoraSDKClient
)
GetImageTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraImageTaskStatus
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
nil
,
err
}
result
:=
sdkClient
.
QueryImageTaskOnce
(
ctx
,
token
,
taskID
,
time
.
Now
()
.
Add
(
-
10
*
time
.
Second
))
if
result
.
Err
!=
nil
{
return
&
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
"failed"
,
ErrorMsg
:
result
.
Err
.
Error
(),
},
nil
}
if
result
.
Done
&&
result
.
ImageURL
!=
""
{
return
&
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
"succeeded"
,
URLs
:
[]
string
{
result
.
ImageURL
},
},
nil
}
status
:=
result
.
Progress
.
Status
if
status
==
""
{
status
=
"processing"
}
return
&
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
status
,
ProgressPct
:
float64
(
result
.
Progress
.
Percent
)
/
100.0
,
},
nil
}
func
(
c
*
SoraSDKClient
)
GetVideoTask
(
ctx
context
.
Context
,
account
*
Account
,
taskID
string
)
(
*
SoraVideoTaskStatus
,
error
)
{
token
,
err
:=
c
.
getAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
nil
,
err
}
// 先查询 pending 列表
result
:=
sdkClient
.
QueryVideoTaskOnce
(
ctx
,
token
,
taskID
,
time
.
Now
()
.
Add
(
-
10
*
time
.
Second
),
0
)
if
result
.
Err
!=
nil
{
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"failed"
,
ErrorMsg
:
result
.
Err
.
Error
(),
},
nil
}
if
!
result
.
Done
{
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
result
.
Progress
.
Status
,
ProgressPct
:
result
.
Progress
.
Percent
,
},
nil
}
// 任务不在 pending 中,查询 drafts 获取下载链接
downloadURLs
,
err
:=
c
.
getVideoTaskDownloadURLs
(
ctx
,
account
,
token
,
taskID
)
if
err
!=
nil
{
errMsg
:=
err
.
Error
()
if
strings
.
Contains
(
errMsg
,
"内容违规"
)
||
strings
.
Contains
(
errMsg
,
"Content violates"
)
{
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"failed"
,
ErrorMsg
:
errMsg
,
},
nil
}
// 可能还在处理中
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"processing"
,
},
nil
}
if
len
(
downloadURLs
)
==
0
{
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"processing"
,
},
nil
}
return
&
SoraVideoTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
downloadURLs
,
},
nil
}
func
(
c
*
SoraSDKClient
)
getVideoTaskDownloadURLs
(
ctx
context
.
Context
,
account
*
Account
,
accessToken
,
taskID
string
)
([]
string
,
error
)
{
raw
,
err
:=
c
.
doSoraBackendJSON
(
ctx
,
account
,
http
.
MethodGet
,
"/project_y/profile/drafts?limit=30"
,
accessToken
,
""
,
nil
)
if
err
!=
nil
{
return
nil
,
err
}
items
:=
gjson
.
GetBytes
(
raw
,
"items"
)
if
!
items
.
Exists
()
||
!
items
.
IsArray
()
{
return
nil
,
fmt
.
Errorf
(
"drafts response missing items for task %s"
,
taskID
)
}
urlSet
:=
make
(
map
[
string
]
struct
{},
4
)
urls
:=
make
([]
string
,
0
,
4
)
items
.
ForEach
(
func
(
_
,
item
gjson
.
Result
)
bool
{
if
strings
.
TrimSpace
(
item
.
Get
(
"task_id"
)
.
String
())
!=
taskID
{
return
true
}
kind
:=
strings
.
TrimSpace
(
item
.
Get
(
"kind"
)
.
String
())
reason
:=
strings
.
TrimSpace
(
item
.
Get
(
"reason_str"
)
.
String
())
markdownReason
:=
strings
.
TrimSpace
(
item
.
Get
(
"markdown_reason_str"
)
.
String
())
if
kind
==
"sora_content_violation"
||
reason
!=
""
||
markdownReason
!=
""
{
if
reason
==
""
{
reason
=
markdownReason
}
if
reason
==
""
{
reason
=
"内容违规"
}
err
=
fmt
.
Errorf
(
"内容违规: %s"
,
reason
)
return
false
}
url
:=
strings
.
TrimSpace
(
item
.
Get
(
"downloadable_url"
)
.
String
())
if
url
==
""
{
url
=
strings
.
TrimSpace
(
item
.
Get
(
"url"
)
.
String
())
}
if
url
==
""
{
return
true
}
if
_
,
exists
:=
urlSet
[
url
];
exists
{
return
true
}
urlSet
[
url
]
=
struct
{}{}
urls
=
append
(
urls
,
url
)
return
true
})
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
urls
)
>
0
{
return
urls
,
nil
}
// 兼容旧 SDK 的兜底逻辑
sdkClient
,
sdkErr
:=
c
.
getSDKClient
(
account
)
if
sdkErr
!=
nil
{
return
nil
,
sdkErr
}
downloadURL
,
sdkErr
:=
sdkClient
.
GetDownloadURL
(
ctx
,
accessToken
,
taskID
)
if
sdkErr
!=
nil
{
return
nil
,
sdkErr
}
if
strings
.
TrimSpace
(
downloadURL
)
==
""
{
return
nil
,
nil
}
return
[]
string
{
downloadURL
},
nil
}
func
(
c
*
SoraSDKClient
)
doSoraBackendJSON
(
ctx
context
.
Context
,
account
*
Account
,
method
string
,
path
string
,
accessToken
string
,
sentinelToken
string
,
payload
map
[
string
]
any
,
)
([]
byte
,
error
)
{
endpoint
:=
"https://sora.chatgpt.com/backend"
+
path
var
body
io
.
Reader
if
payload
!=
nil
{
raw
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
nil
,
err
}
body
=
bytes
.
NewReader
(
raw
)
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
method
,
endpoint
,
body
)
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Accept"
,
"application/json, text/plain, */*"
)
req
.
Header
.
Set
(
"Origin"
,
"https://sora.chatgpt.com"
)
req
.
Header
.
Set
(
"Referer"
,
"https://sora.chatgpt.com/"
)
req
.
Header
.
Set
(
"User-Agent"
,
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
if
payload
!=
nil
{
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
}
if
strings
.
TrimSpace
(
sentinelToken
)
!=
""
{
req
.
Header
.
Set
(
"openai-sentinel-token"
,
sentinelToken
)
}
proxyURL
:=
c
.
resolveProxyURL
(
account
)
accountID
:=
int64
(
0
)
accountConcurrency
:=
0
if
account
!=
nil
{
accountID
=
account
.
ID
accountConcurrency
=
account
.
Concurrency
}
var
resp
*
http
.
Response
if
c
.
httpUpstream
!=
nil
{
resp
,
err
=
c
.
httpUpstream
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
else
{
resp
,
err
=
http
.
DefaultClient
.
Do
(
req
)
}
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
raw
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
4
<<
20
))
if
err
!=
nil
{
return
nil
,
err
}
if
resp
.
StatusCode
!=
http
.
StatusOK
&&
resp
.
StatusCode
!=
http
.
StatusCreated
{
return
nil
,
fmt
.
Errorf
(
"HTTP %d: %s"
,
resp
.
StatusCode
,
truncateForLog
(
raw
,
256
))
}
return
raw
,
nil
}
// --- 内部方法 ---
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
func
(
c
*
SoraSDKClient
)
getSDKClient
(
account
*
Account
)
(
*
sora
.
Client
,
error
)
{
proxyURL
:=
c
.
resolveProxyURL
(
account
)
if
v
,
ok
:=
c
.
sdkClients
.
Load
(
proxyURL
);
ok
{
if
cli
,
ok2
:=
v
.
(
*
sora
.
Client
);
ok2
{
return
cli
,
nil
}
}
client
,
err
:=
sora
.
New
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建 Sora SDK 客户端失败: %w"
,
err
)
}
actual
,
_
:=
c
.
sdkClients
.
LoadOrStore
(
proxyURL
,
client
)
if
cli
,
ok
:=
actual
.
(
*
sora
.
Client
);
ok
{
return
cli
,
nil
}
return
client
,
nil
}
func
(
c
*
SoraSDKClient
)
resolveProxyURL
(
account
*
Account
)
string
{
if
account
==
nil
||
account
.
ProxyID
==
nil
||
account
.
Proxy
==
nil
{
return
""
}
return
strings
.
TrimSpace
(
account
.
Proxy
.
URL
())
}
// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。
// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。
func
(
c
*
SoraSDKClient
)
getAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
// 优先尝试 OpenAI Token Provider
allowProvider
:=
c
.
allowOpenAITokenProvider
(
account
)
var
providerErr
error
if
allowProvider
&&
c
.
tokenProvider
!=
nil
{
token
,
err
:=
c
.
tokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
c
.
debugLogf
(
"token_selected account_id=%d source=openai_token_provider"
,
account
.
ID
)
return
token
,
nil
}
providerErr
=
err
if
err
!=
nil
&&
c
.
debugEnabled
()
{
c
.
debugLogf
(
"token_provider_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
}
}
// 尝试直接使用 credentials 中的 access_token
token
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"access_token"
))
if
token
!=
""
{
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
!=
nil
&&
time
.
Until
(
*
expiresAt
)
<=
2
*
time
.
Minute
{
refreshed
,
refreshErr
:=
c
.
recoverAccessToken
(
ctx
,
account
,
"access_token_expiring"
)
if
refreshErr
==
nil
&&
strings
.
TrimSpace
(
refreshed
)
!=
""
{
return
refreshed
,
nil
}
}
return
token
,
nil
}
// 尝试通过 session_token 或 refresh_token 恢复
recovered
,
recoverErr
:=
c
.
recoverAccessToken
(
ctx
,
account
,
"access_token_missing"
)
if
recoverErr
==
nil
&&
strings
.
TrimSpace
(
recovered
)
!=
""
{
return
recovered
,
nil
}
if
providerErr
!=
nil
{
return
""
,
providerErr
}
return
""
,
errors
.
New
(
"access_token not found"
)
}
// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token
func
(
c
*
SoraSDKClient
)
recoverAccessToken
(
ctx
context
.
Context
,
account
*
Account
,
reason
string
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
// 先尝试 session_token
if
sessionToken
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"session_token"
));
sessionToken
!=
""
{
accessToken
,
expiresAt
,
err
:=
c
.
exchangeSessionToken
(
ctx
,
account
,
sessionToken
)
if
err
==
nil
&&
strings
.
TrimSpace
(
accessToken
)
!=
""
{
c
.
applyRecoveredToken
(
ctx
,
account
,
accessToken
,
""
,
expiresAt
,
sessionToken
)
return
accessToken
,
nil
}
}
// 再尝试 refresh_token
refreshToken
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"refresh_token"
))
if
refreshToken
==
""
{
return
""
,
errors
.
New
(
"session_token/refresh_token not found"
)
}
sdkClient
,
err
:=
c
.
getSDKClient
(
account
)
if
err
!=
nil
{
return
""
,
err
}
// 尝试多个 client_id
clientIDs
:=
[]
string
{
strings
.
TrimSpace
(
account
.
GetCredential
(
"client_id"
)),
openaioauth
.
SoraClientID
,
openaioauth
.
ClientID
,
}
tried
:=
make
(
map
[
string
]
struct
{},
len
(
clientIDs
))
var
lastErr
error
for
_
,
clientID
:=
range
clientIDs
{
if
clientID
==
""
{
continue
}
if
_
,
ok
:=
tried
[
clientID
];
ok
{
continue
}
tried
[
clientID
]
=
struct
{}{}
newAccess
,
newRefresh
,
refreshErr
:=
sdkClient
.
RefreshAccessToken
(
ctx
,
refreshToken
,
clientID
)
if
refreshErr
!=
nil
{
lastErr
=
refreshErr
continue
}
if
strings
.
TrimSpace
(
newAccess
)
==
""
{
lastErr
=
errors
.
New
(
"refreshed access_token is empty"
)
continue
}
c
.
applyRecoveredToken
(
ctx
,
account
,
newAccess
,
newRefresh
,
""
,
""
)
return
newAccess
,
nil
}
if
lastErr
!=
nil
{
return
""
,
lastErr
}
return
""
,
errors
.
New
(
"no available client_id for refresh_token exchange"
)
}
// exchangeSessionToken 通过 session_token 换取 access_token
func
(
c
*
SoraSDKClient
)
exchangeSessionToken
(
ctx
context
.
Context
,
account
*
Account
,
sessionToken
string
)
(
string
,
string
,
error
)
{
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
"https://sora.chatgpt.com/api/auth/session"
,
nil
)
if
err
!=
nil
{
return
""
,
""
,
err
}
req
.
Header
.
Set
(
"Cookie"
,
"__Secure-next-auth.session-token="
+
sessionToken
)
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
req
.
Header
.
Set
(
"Origin"
,
"https://sora.chatgpt.com"
)
req
.
Header
.
Set
(
"Referer"
,
"https://sora.chatgpt.com/"
)
req
.
Header
.
Set
(
"User-Agent"
,
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
proxyURL
:=
c
.
resolveProxyURL
(
account
)
accountID
:=
int64
(
0
)
accountConcurrency
:=
0
if
account
!=
nil
{
accountID
=
account
.
ID
accountConcurrency
=
account
.
Concurrency
}
var
resp
*
http
.
Response
if
c
.
httpUpstream
!=
nil
{
resp
,
err
=
c
.
httpUpstream
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
else
{
resp
,
err
=
http
.
DefaultClient
.
Do
(
req
)
}
if
err
!=
nil
{
return
""
,
""
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
err
!=
nil
{
return
""
,
""
,
err
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
""
,
""
,
fmt
.
Errorf
(
"session exchange failed: %d"
,
resp
.
StatusCode
)
}
accessToken
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"accessToken"
)
.
String
())
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"session exchange missing accessToken"
)
}
expiresAt
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"expires"
)
.
String
())
return
accessToken
,
expiresAt
,
nil
}
// applyRecoveredToken 将恢复的 token 写入账号内存和数据库
func
(
c
*
SoraSDKClient
)
applyRecoveredToken
(
ctx
context
.
Context
,
account
*
Account
,
accessToken
,
refreshToken
,
expiresAt
,
sessionToken
string
)
{
if
account
==
nil
{
return
}
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
if
strings
.
TrimSpace
(
accessToken
)
!=
""
{
account
.
Credentials
[
"access_token"
]
=
accessToken
}
if
strings
.
TrimSpace
(
refreshToken
)
!=
""
{
account
.
Credentials
[
"refresh_token"
]
=
refreshToken
}
if
strings
.
TrimSpace
(
expiresAt
)
!=
""
{
account
.
Credentials
[
"expires_at"
]
=
expiresAt
}
if
strings
.
TrimSpace
(
sessionToken
)
!=
""
{
account
.
Credentials
[
"session_token"
]
=
sessionToken
}
if
c
.
accountRepo
!=
nil
{
if
err
:=
persistAccountCredentials
(
ctx
,
c
.
accountRepo
,
account
,
account
.
Credentials
);
err
!=
nil
&&
c
.
debugEnabled
()
{
c
.
debugLogf
(
"persist_recovered_token_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
}
}
c
.
updateSoraAccountExtension
(
ctx
,
account
,
accessToken
,
refreshToken
,
sessionToken
)
}
func
(
c
*
SoraSDKClient
)
updateSoraAccountExtension
(
ctx
context
.
Context
,
account
*
Account
,
accessToken
,
refreshToken
,
sessionToken
string
)
{
if
c
==
nil
||
c
.
soraAccountRepo
==
nil
||
account
==
nil
||
account
.
ID
<=
0
{
return
}
updates
:=
make
(
map
[
string
]
any
)
if
strings
.
TrimSpace
(
accessToken
)
!=
""
&&
strings
.
TrimSpace
(
refreshToken
)
!=
""
{
updates
[
"access_token"
]
=
accessToken
updates
[
"refresh_token"
]
=
refreshToken
}
if
strings
.
TrimSpace
(
sessionToken
)
!=
""
{
updates
[
"session_token"
]
=
sessionToken
}
if
len
(
updates
)
==
0
{
return
}
if
err
:=
c
.
soraAccountRepo
.
Upsert
(
ctx
,
account
.
ID
,
updates
);
err
!=
nil
&&
c
.
debugEnabled
()
{
c
.
debugLogf
(
"persist_sora_extension_failed account_id=%d err=%s"
,
account
.
ID
,
logredact
.
RedactText
(
err
.
Error
()))
}
}
func
(
c
*
SoraSDKClient
)
allowOpenAITokenProvider
(
account
*
Account
)
bool
{
if
c
==
nil
||
c
.
tokenProvider
==
nil
{
return
false
}
if
account
!=
nil
&&
account
.
Platform
==
PlatformSora
{
return
c
.
cfg
!=
nil
&&
c
.
cfg
.
Sora
.
Client
.
UseOpenAITokenProvider
}
return
true
}
// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError
func
(
c
*
SoraSDKClient
)
wrapSDKError
(
err
error
,
account
*
Account
)
error
{
if
err
==
nil
{
return
nil
}
msg
:=
err
.
Error
()
statusCode
:=
http
.
StatusBadGateway
if
strings
.
Contains
(
msg
,
"HTTP 401"
)
||
strings
.
Contains
(
msg
,
"HTTP 403"
)
{
statusCode
=
http
.
StatusUnauthorized
}
else
if
strings
.
Contains
(
msg
,
"HTTP 429"
)
{
statusCode
=
http
.
StatusTooManyRequests
}
else
if
strings
.
Contains
(
msg
,
"HTTP 404"
)
{
statusCode
=
http
.
StatusNotFound
}
accountID
:=
int64
(
0
)
if
account
!=
nil
{
accountID
=
account
.
ID
}
logger
.
LegacyPrintf
(
"service.sora_sdk"
,
"[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s"
,
accountID
,
statusCode
,
logredact
.
RedactText
(
msg
),
)
return
&
SoraUpstreamError
{
StatusCode
:
statusCode
,
Message
:
msg
,
}
}
func
(
c
*
SoraSDKClient
)
debugEnabled
()
bool
{
return
c
!=
nil
&&
c
.
cfg
!=
nil
&&
c
.
cfg
.
Sora
.
Client
.
Debug
}
func
(
c
*
SoraSDKClient
)
debugLogf
(
format
string
,
args
...
any
)
{
if
c
.
debugEnabled
()
{
log
.
Printf
(
"[SoraSDK] "
+
format
,
args
...
)
}
}
backend/internal/service/sora_upstream_forwarder.go
deleted
100644 → 0
View file @
f585a15e
package
service
import
(
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func
(
s
*
SoraGatewayService
)
forwardToUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
clientStream
bool
,
startTime
time
.
Time
,
)
(
*
ForwardResult
,
error
)
{
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Sora apikey account missing api_key credential"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"sora apikey account %d missing api_key"
,
account
.
ID
)
}
baseURL
:=
account
.
GetBaseURL
()
if
baseURL
==
""
{
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Sora apikey account missing base_url"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"sora apikey account %d missing base_url"
,
account
.
ID
)
}
// 校验 scheme 合法性(仅允许 http/https)
if
!
strings
.
HasPrefix
(
baseURL
,
"http://"
)
&&
!
strings
.
HasPrefix
(
baseURL
,
"https://"
)
{
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Sora apikey base_url must start with http:// or https://"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"sora apikey account %d invalid base_url scheme: %s"
,
account
.
ID
,
baseURL
)
}
upstreamURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
"/sora/v1/chat/completions"
// 构建上游请求
upstreamReq
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to create upstream request"
,
clientStream
)
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
upstreamReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamReq
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
// 透传客户端的部分请求头
for
_
,
header
:=
range
[]
string
{
"Accept"
,
"Accept-Encoding"
}
{
if
v
:=
c
.
GetHeader
(
header
);
v
!=
""
{
upstreamReq
.
Header
.
Set
(
header
,
v
)
}
}
logger
.
LegacyPrintf
(
"service.sora"
,
"[ForwardUpstream] account=%d url=%s"
,
account
.
ID
,
upstreamURL
)
// 获取代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
s
.
writeSoraError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Failed to connect to upstream Sora service"
,
clientStream
)
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
http
.
StatusBadGateway
,
}
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 错误响应处理
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
64
*
1024
))
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
,
ResponseHeaders
:
resp
.
Header
.
Clone
(),
}
}
// 非转移错误,直接透传给客户端
c
.
Status
(
resp
.
StatusCode
)
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
v
:=
range
values
{
c
.
Writer
.
Header
()
.
Add
(
key
,
v
)
}
}
if
_
,
err
:=
c
.
Writer
.
Write
(
respBody
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"write upstream error response: %w"
,
err
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
// 成功响应 — 直接透传
c
.
Status
(
resp
.
StatusCode
)
for
key
,
values
:=
range
resp
.
Header
{
lower
:=
strings
.
ToLower
(
key
)
// 透传内容相关头部
if
lower
==
"content-type"
||
lower
==
"transfer-encoding"
||
lower
==
"cache-control"
||
lower
==
"x-request-id"
{
for
_
,
v
:=
range
values
{
c
.
Writer
.
Header
()
.
Add
(
key
,
v
)
}
}
}
// 流式复制响应体
if
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
);
ok
&&
clientStream
{
buf
:=
make
([]
byte
,
4096
)
for
{
n
,
readErr
:=
resp
.
Body
.
Read
(
buf
)
if
n
>
0
{
if
_
,
err
:=
c
.
Writer
.
Write
(
buf
[
:
n
]);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stream upstream response write: %w"
,
err
)
}
flusher
.
Flush
()
}
if
readErr
!=
nil
{
break
}
}
}
else
{
if
_
,
err
:=
io
.
Copy
(
c
.
Writer
,
resp
.
Body
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"copy upstream response: %w"
,
err
)
}
}
duration
:=
time
.
Since
(
startTime
)
return
&
ForwardResult
{
RequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Model
:
""
,
// 由调用方填充
Stream
:
clientStream
,
Duration
:
duration
,
},
nil
}
backend/internal/service/token_cache_invalidator.go
View file @
d757df8a
...
...
@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键
keysToDelete
=
append
(
keysToDelete
,
AntigravityTokenCacheKey
(
account
))
keysToDelete
=
append
(
keysToDelete
,
"ag:"
+
accountIDKey
)
case
PlatformOpenAI
,
PlatformSora
:
case
PlatformOpenAI
:
keysToDelete
=
append
(
keysToDelete
,
OpenAITokenCacheKey
(
account
))
case
PlatformAnthropic
:
keysToDelete
=
append
(
keysToDelete
,
ClaudeTokenCacheKey
(
account
))
...
...
backend/internal/service/token_refresh_service.go
View file @
d757df8a
...
...
@@ -60,7 +60,6 @@ func NewTokenRefreshService(
}
openAIRefresher
:=
NewOpenAITokenRefresher
(
openaiOAuthService
,
accountRepo
)
openAIRefresher
.
SetSyncLinkedSoraAccounts
(
cfg
.
TokenRefresh
.
SyncLinkedSoraAccounts
)
claudeRefresher
:=
NewClaudeTokenRefresher
(
oauthService
)
geminiRefresher
:=
NewGeminiTokenRefresher
(
geminiOAuthService
)
...
...
@@ -85,18 +84,6 @@ func NewTokenRefreshService(
return
s
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
// 需要在 Start() 之前调用
func
(
s
*
TokenRefreshService
)
SetSoraAccountRepo
(
repo
SoraAccountRepository
)
{
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
for
_
,
refresher
:=
range
s
.
refreshers
{
if
openaiRefresher
,
ok
:=
refresher
.
(
*
OpenAITokenRefresher
);
ok
{
openaiRefresher
.
SetSoraAccountRepo
(
repo
)
}
}
}
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
func
(
s
*
TokenRefreshService
)
SetPrivacyDeps
(
factory
PrivacyClientFactory
,
proxyRepo
ProxyRepository
)
{
s
.
privacyClientFactory
=
factory
...
...
backend/internal/service/token_refresher.go
View file @
d757df8a
...
...
@@ -2,7 +2,6 @@ package service
import
(
"context"
"log"
"time"
)
...
...
@@ -73,8 +72,6 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
type
OpenAITokenRefresher
struct
{
openaiOAuthService
*
OpenAIOAuthService
accountRepo
AccountRepository
soraAccountRepo
SoraAccountRepository
// Sora 扩展表仓储,用于双表同步
syncLinkedSora
bool
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
...
...
@@ -90,20 +87,7 @@ func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
return
OpenAITokenCacheKey
(
account
)
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 Token 刷新时同步更新 sora_accounts 表
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
func
(
r
*
OpenAITokenRefresher
)
SetSoraAccountRepo
(
repo
SoraAccountRepository
)
{
r
.
soraAccountRepo
=
repo
}
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
func
(
r
*
OpenAITokenRefresher
)
SetSyncLinkedSoraAccounts
(
enabled
bool
)
{
r
.
syncLinkedSora
=
enabled
}
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func
(
r
*
OpenAITokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformOpenAI
&&
account
.
Type
==
AccountTypeOAuth
}
...
...
@@ -121,7 +105,6 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
// 刷新成功后,异步同步关联的 Sora 账号
func
(
r
*
OpenAITokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
tokenInfo
,
err
:=
r
.
openaiOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
...
...
@@ -132,68 +115,5 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
newCredentials
:=
r
.
openaiOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
newCredentials
=
MergeCredentials
(
account
.
Credentials
,
newCredentials
)
// 异步同步关联的 Sora 账号(不阻塞主流程)
if
r
.
accountRepo
!=
nil
&&
r
.
syncLinkedSora
{
go
r
.
syncLinkedSoraAccounts
(
context
.
Background
(),
account
.
ID
,
newCredentials
)
}
return
newCredentials
,
nil
}
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
// 该方法异步执行,失败只记录日志,不影响主流程
//
// 同步策略:
// 1. 更新 accounts.credentials(主表)
// 2. 更新 sora_accounts 扩展表(如果 soraAccountRepo 已设置)
//
// 超时控制:30 秒,防止数据库阻塞导致 goroutine 泄漏
func
(
r
*
OpenAITokenRefresher
)
syncLinkedSoraAccounts
(
ctx
context
.
Context
,
openaiAccountID
int64
,
newCredentials
map
[
string
]
any
)
{
// 添加超时控制,防止 goroutine 泄漏
ctx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
30
*
time
.
Second
)
defer
cancel
()
// 1. 查找所有关联的 Sora 账号(限定 platform='sora')
soraAccounts
,
err
:=
r
.
accountRepo
.
FindByExtraField
(
ctx
,
"linked_openai_account_id"
,
openaiAccountID
)
if
err
!=
nil
{
log
.
Printf
(
"[TokenSync] 查找关联 Sora 账号失败: openai_account_id=%d err=%v"
,
openaiAccountID
,
err
)
return
}
if
len
(
soraAccounts
)
==
0
{
// 没有关联的 Sora 账号,直接返回
return
}
// 2. 同步更新每个 Sora 账号的双表数据
for
_
,
soraAccount
:=
range
soraAccounts
{
// 2.1 更新 accounts.credentials(主表)
soraAccount
.
Credentials
[
"access_token"
]
=
newCredentials
[
"access_token"
]
soraAccount
.
Credentials
[
"refresh_token"
]
=
newCredentials
[
"refresh_token"
]
if
expiresAt
,
ok
:=
newCredentials
[
"expires_at"
];
ok
{
soraAccount
.
Credentials
[
"expires_at"
]
=
expiresAt
}
if
err
:=
r
.
accountRepo
.
Update
(
ctx
,
&
soraAccount
);
err
!=
nil
{
log
.
Printf
(
"[TokenSync] 更新 Sora accounts 表失败: sora_account_id=%d openai_account_id=%d err=%v"
,
soraAccount
.
ID
,
openaiAccountID
,
err
)
continue
}
// 2.2 更新 sora_accounts 扩展表(如果仓储已设置)
if
r
.
soraAccountRepo
!=
nil
{
soraUpdates
:=
map
[
string
]
any
{
"access_token"
:
newCredentials
[
"access_token"
],
"refresh_token"
:
newCredentials
[
"refresh_token"
],
}
if
err
:=
r
.
soraAccountRepo
.
Upsert
(
ctx
,
soraAccount
.
ID
,
soraUpdates
);
err
!=
nil
{
log
.
Printf
(
"[TokenSync] 更新 sora_accounts 表失败: account_id=%d openai_account_id=%d err=%v"
,
soraAccount
.
ID
,
openaiAccountID
,
err
)
// 继续处理其他账号,不中断
}
}
log
.
Printf
(
"[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v"
,
soraAccount
.
ID
,
openaiAccountID
,
r
.
soraAccountRepo
!=
nil
)
}
}
backend/internal/service/token_refresher_test.go
View file @
d757df8a
...
...
@@ -242,12 +242,6 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
accType
:
AccountTypeOAuth
,
want
:
true
,
},
{
name
:
"sora oauth - cannot refresh directly"
,
platform
:
PlatformSora
,
accType
:
AccountTypeOAuth
,
want
:
false
,
},
{
name
:
"openai apikey - cannot refresh"
,
platform
:
PlatformOpenAI
,
...
...
backend/internal/service/usage_log.go
View file @
d757df8a
...
...
@@ -110,7 +110,7 @@ type UsageLog struct {
ModelMappingChain
*
string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier
*
string
// BillingMode 计费模式:token/image
(sora 路径为 nil)
// BillingMode 计费模式:token/image
BillingMode
*
string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier
*
string
...
...
backend/internal/service/user.go
View file @
d757df8a
...
...
@@ -25,10 +25,6 @@ type User struct {
// map[groupID]rateMultiplier
GroupRates
map
[
int64
]
float64
// Sora 存储配额
SoraStorageQuotaBytes
int64
// 用户级 Sora 存储配额(0 表示使用分组或系统默认值)
SoraStorageUsedBytes
int64
// Sora 存储已用量
// TOTP 双因素认证字段
TotpSecretEncrypted
*
string
// AES-256-GCM 加密的 TOTP 密钥
TotpEnabled
bool
// 是否启用 TOTP
...
...
backend/internal/service/wire.go
View file @
d757df8a
...
...
@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
// ProvideTokenRefreshService creates and starts TokenRefreshService
func
ProvideTokenRefreshService
(
accountRepo
AccountRepository
,
soraAccountRepo
SoraAccountRepository
,
// Sora 扩展表仓储,用于双表同步
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
...
...
@@ -54,8 +53,6 @@ func ProvideTokenRefreshService(
refreshAPI
*
OAuthRefreshAPI
,
)
*
TokenRefreshService
{
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cacheInvalidator
,
schedulerCache
,
cfg
,
tempUnschedCache
)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc
.
SetSoraAccountRepo
(
soraAccountRepo
)
// 注入 OpenAI privacy opt-out 依赖
svc
.
SetPrivacyDeps
(
privacyClientFactory
,
proxyRepo
)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
...
...
@@ -281,30 +278,6 @@ func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
return
sink
}
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
func
ProvideSoraMediaStorage
(
cfg
*
config
.
Config
)
*
SoraMediaStorage
{
return
NewSoraMediaStorage
(
cfg
)
}
func
ProvideSoraSDKClient
(
cfg
*
config
.
Config
,
httpUpstream
HTTPUpstream
,
tokenProvider
*
OpenAITokenProvider
,
accountRepo
AccountRepository
,
soraAccountRepo
SoraAccountRepository
,
)
*
SoraSDKClient
{
client
:=
NewSoraSDKClient
(
cfg
,
httpUpstream
,
tokenProvider
)
client
.
SetAccountRepositories
(
accountRepo
,
soraAccountRepo
)
return
client
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func
ProvideSoraMediaCleanupService
(
storage
*
SoraMediaStorage
,
cfg
*
config
.
Config
)
*
SoraMediaCleanupService
{
svc
:=
NewSoraMediaCleanupService
(
storage
,
cfg
)
svc
.
Start
()
return
svc
}
func
buildIdempotencyConfig
(
cfg
*
config
.
Config
)
IdempotencyConfig
{
idempotencyCfg
:=
DefaultIdempotencyConfig
()
if
cfg
!=
nil
{
...
...
@@ -425,11 +398,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementService
,
NewAdminService
,
NewGatewayService
,
ProvideSoraMediaStorage
,
ProvideSoraMediaCleanupService
,
ProvideSoraSDKClient
,
wire
.
Bind
(
new
(
SoraClient
),
new
(
*
SoraSDKClient
)),
NewSoraGatewayService
,
NewOpenAIGatewayService
,
NewOAuthService
,
NewOpenAIOAuthService
,
...
...
backend/internal/util/
soraerror/soraerror
.go
→
backend/internal/util/
httputil/httputil
.go
View file @
d757df8a
package
soraerror
package
httputil
import
(
"encoding/json"
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment